diff --git a/LICENSE b/LICENSE index 9b364a4d00079..21c42e9a20fa3 100644 --- a/LICENSE +++ b/LICENSE @@ -814,6 +814,7 @@ BSD-style licenses The following components are provided under a BSD-style license. See project link for details. (BSD 3 Clause) core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) + (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model) (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.3 - http://jblas.org/) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) diff --git a/assembly/pom.xml b/assembly/pom.xml index 20593e710dedb..2b4d0a990bf22 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -194,7 +194,6 @@ org.apache.maven.plugins maven-assembly-plugin - 2.4 dist diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 3d068dd3a2739..db09fa27e51a6 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -61,7 +61,10 @@ if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. -for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %*"') do ( +set LAUNCHER_OUTPUT=%temp%\spark-class-launcher-output-%RANDOM%.txt +"%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %* > %LAUNCHER_OUTPUT% +for /f "tokens=*" %%i in (%LAUNCHER_OUTPUT%) do ( set SPARK_CMD=%%i ) +del %LAUNCHER_OUTPUT% %SPARK_CMD% diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 67f81d33361e1..43c4288912b18 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -3,7 +3,7 @@ # This file is sourced when running various Spark programs. # Copy it as spark-env.sh and edit that to configure Spark for your site. -# Options read when launching programs locally with +# Options read when launching programs locally with # ./bin/run-example or ./bin/spark-submit # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files # - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node @@ -39,6 +39,7 @@ # - SPARK_WORKER_DIR, to set the working directory of worker processes # - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y") # - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y") +# - SPARK_SHUFFLE_OPTS, to set config properties only for the external shuffle service (e.g. "-Dx=y") # - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y") # - SPARK_PUBLIC_DNS, to set the public dns name of the master or workers diff --git a/core/pom.xml b/core/pom.xml index 5e89d548cd47f..2dfb00d7ecf26 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -95,6 +95,11 @@ spark-network-shuffle_${scala.binary.version} ${project.version} + + org.apache.spark + spark-unsafe_${scala.binary.version} + ${project.version} + net.java.dev.jets3t jets3t @@ -478,7 +483,6 @@ org.codehaus.mojo exec-maven-plugin - 1.3.2 sparkr-pkg diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/core/src/main/java/org/apache/spark/api/java/function/Function0.java new file mode 100644 index 0000000000000..38e410c5debe6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/Function0.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A zero-argument function that returns an R. + */ +public interface Function0 extends Serializable { + public R call() throws Exception; +} diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 68d05d5b02537..f2b024ff6cb67 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -76,13 +76,15 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) private var timeoutCheckingTask: ScheduledFuture[_] = null - private val timeoutCheckingThread = - ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-timeout-checking-thread") + // "eventLoopThread" is used to run some pretty fast actions. The actions running in it should not + // block the thread for a long time. + private val eventLoopThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-receiver-event-loop-thread") private val killExecutorThread = ThreadUtils.newDaemonSingleThreadExecutor("kill-executor-thread") override def onStart(): Unit = { - timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable { + timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { Option(self).foreach(_.send(ExpireDeadHosts)) } @@ -99,11 +101,15 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => if (scheduler != null) { - val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) - val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) executorLastSeen(executorId) = System.currentTimeMillis() - context.reply(response) + eventLoopThread.submit(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + context.reply(response) + } + }) } else { // Because Executor will sleep several seconds before sending the first "Heartbeat", this // case rarely happens. However, if it really happens, log it and ask the executor to @@ -125,7 +131,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) if (sc.supportDynamicAllocation) { // Asynchronously kill the executor to avoid blocking the current thread killExecutorThread.submit(new Runnable { - override def run(): Unit = sc.killExecutor(executorId) + override def run(): Unit = Utils.tryLogNonFatalError { + sc.killExecutor(executorId) + } }) } executorLastSeen.remove(executorId) @@ -137,7 +145,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) if (timeoutCheckingTask != null) { timeoutCheckingTask.cancel(true) } - timeoutCheckingThread.shutdownNow() + eventLoopThread.shutdownNow() killExecutorThread.shutdownNow() } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index d65c94e410662..16072283edbe9 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -106,7 +106,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging */ protected def askTracker[T: ClassTag](message: Any): T = { try { - trackerEndpoint.askWithReply[T](message) + trackerEndpoint.askWithRetry[T](message) } catch { case e: Exception => logError("Error communicating with MapOutputTracker", e) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index c1996e08756a6..a8fc90ad2050e 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -211,7 +211,74 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { Utils.timeStringAsMs(get(key, defaultValue)) } + /** + * Get a size parameter as bytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then bytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsBytes(key: String): Long = { + Utils.byteStringAsBytes(get(key)) + } + + /** + * Get a size parameter as bytes, falling back to a default if not set. If no + * suffix is provided then bytes are assumed. + */ + def getSizeAsBytes(key: String, defaultValue: String): Long = { + Utils.byteStringAsBytes(get(key, defaultValue)) + } + + /** + * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Kibibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsKb(key: String): Long = { + Utils.byteStringAsKb(get(key)) + } + + /** + * Get a size parameter as Kibibytes, falling back to a default if not set. If no + * suffix is provided then Kibibytes are assumed. + */ + def getSizeAsKb(key: String, defaultValue: String): Long = { + Utils.byteStringAsKb(get(key, defaultValue)) + } + + /** + * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Mebibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsMb(key: String): Long = { + Utils.byteStringAsMb(get(key)) + } + + /** + * Get a size parameter as Mebibytes, falling back to a default if not set. If no + * suffix is provided then Mebibytes are assumed. + */ + def getSizeAsMb(key: String, defaultValue: String): Long = { + Utils.byteStringAsMb(get(key, defaultValue)) + } + + /** + * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Gibibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsGb(key: String): Long = { + Utils.byteStringAsGb(get(key)) + } + /** + * Get a size parameter as Gibibytes, falling back to a default if not set. If no + * suffix is provided then Gibibytes are assumed. + */ + def getSizeAsGb(key: String, defaultValue: String): Long = { + Utils.byteStringAsGb(get(key, defaultValue)) + } + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) @@ -407,7 +474,13 @@ private[spark] object SparkConf extends Logging { "The spark.cache.class property is no longer being used! Specify storage levels using " + "the RDD.persist() method instead."), DeprecatedConfig("spark.yarn.user.classpath.first", "1.3", - "Please use spark.{driver,executor}.userClassPathFirst instead.")) + "Please use spark.{driver,executor}.userClassPathFirst instead."), + DeprecatedConfig("spark.kryoserializer.buffer.mb", "1.4", + "Please use spark.kryoserializer.buffer instead. The default value for " + + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + + "are no longer accepted. To specify the equivalent now, one may use '64k'.") + ) + Map(configs.map { cfg => (cfg.key -> cfg) }:_*) } @@ -432,6 +505,21 @@ private[spark] object SparkConf extends Logging { AlternateConfig("spark.yarn.applicationMaster.waitTries", "1.3", // Translate old value to a duration, with 10s wait time per try. translation = s => s"${s.toLong * 10}s")), + "spark.reducer.maxSizeInFlight" -> Seq( + AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), + "spark.kryoserializer.buffer" -> + Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + translation = s => s"${s.toDouble * 1000}k")), + "spark.kryoserializer.buffer.max" -> Seq( + AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), + "spark.shuffle.file.buffer" -> Seq( + AlternateConfig("spark.shuffle.file.buffer.kb", "1.4")), + "spark.executor.logs.rolling.maxSize" -> Seq( + AlternateConfig("spark.executor.logs.rolling.size.maxBytes", "1.4")), + "spark.io.compression.snappy.blockSize" -> Seq( + AlternateConfig("spark.io.compression.snappy.block.size", "1.4")), + "spark.io.compression.lz4.blockSize" -> Seq( + AlternateConfig("spark.io.compression.lz4.block.size", "1.4")), "spark.rpc.numRetries" -> Seq( AlternateConfig("spark.akka.num.retries", "1.4")), "spark.rpc.retry.wait" -> Seq( diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 65b903a55d5bd..bae951f388337 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -555,7 +555,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli SparkEnv.executorActorSystemName, RpcAddress(host, port), ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME) - Some(endpointRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump)) + Some(endpointRef.askWithRetry[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { case e: Exception => @@ -713,7 +713,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli RDD[(String, String)] = { assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) - NewFileInputFormat.addInputPath(job, new Path(path)) + // Use setInputPaths so that wholeTextFiles aligns with hadoopFile/textFile in taking + // comma separated files as input. (see SPARK-7155) + NewFileInputFormat.setInputPaths(job, path) val updateConf = job.getConfiguration new WholeTextFileRDD( this, @@ -759,7 +761,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli RDD[(String, PortableDataStream)] = { assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) - NewFileInputFormat.addInputPath(job, new Path(path)) + // Use setInputPaths so that binaryFiles aligns with hadoopFile/textFile in taking + // comma separated files as input. (see SPARK-7155) + NewFileInputFormat.setInputPaths(job, path) val updateConf = job.getConfiguration new BinaryFileRDD( this, @@ -935,7 +939,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // The call to new NewHadoopJob automatically adds security credentials to conf, // so we don't need to explicitly add them ourselves val job = new NewHadoopJob(conf) - NewFileInputFormat.addInputPath(job, new Path(path)) + // Use setInputPaths so that newAPIHadoopFile aligns with hadoopFile/textFile in taking + // comma separated files as input. (see SPARK-7155) + NewFileInputFormat.setInputPaths(job, path) val updatedConf = job.getConfiguration new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf).setName(path) } @@ -1396,6 +1402,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Register an RDD to be persisted in memory and/or disk storage */ private[spark] def persistRDD(rdd: RDD[_]) { + _executorAllocationManager.foreach { _ => + logWarning( + s"Dynamic allocation currently does not support cached RDDs. Cached data for RDD " + + s"${rdd.id} will be lost when executors are removed.") + } persistentRdds(rdd.id) = rdd } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 959aefabd8de4..0c4d28f786edd 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -40,6 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator} import org.apache.spark.util.{RpcUtils, Utils} /** @@ -69,6 +70,7 @@ class SparkEnv ( val sparkFilesDir: String, val metricsSystem: MetricsSystem, val shuffleMemoryManager: ShuffleMemoryManager, + val executorMemoryManager: ExecutorMemoryManager, val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { @@ -382,6 +384,15 @@ object SparkEnv extends Logging { new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) + val executorMemoryManager: ExecutorMemoryManager = { + val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) { + MemoryAllocator.UNSAFE + } else { + MemoryAllocator.HEAP + } + new ExecutorMemoryManager(allocator) + } + val envInstance = new SparkEnv( executorId, rpcEnv, @@ -398,6 +409,7 @@ object SparkEnv extends Logging { sparkFilesDir, metricsSystem, shuffleMemoryManager, + executorMemoryManager, outputCommitCoordinator, conf) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 7d7fe1a446313..d09e17dea0911 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,6 +21,7 @@ import java.io.Serializable import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.TaskCompletionListener @@ -133,4 +134,9 @@ abstract class TaskContext extends Serializable { /** ::DeveloperApi:: */ @DeveloperApi def taskMetrics(): TaskMetrics + + /** + * Returns the manager for this task's managed memory. + */ + private[spark] def taskMemoryManager(): TaskMemoryManager } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 337c8e4ebebcd..b4d572cb52313 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.apache.spark.executor.TaskMetrics +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} import scala.collection.mutable.ArrayBuffer @@ -27,6 +28,7 @@ private[spark] class TaskContextImpl( val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, + override val taskMemoryManager: TaskMemoryManager, val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 398ca41e16151..fe6320b504e15 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -105,23 +105,18 @@ private[spark] object TestUtils { URI.create(s"string:///${name.replace(".", "/")}${SOURCE.extension}") } - private class JavaSourceFromString(val name: String, val code: String) + private[spark] class JavaSourceFromString(val name: String, val code: String) extends SimpleJavaFileObject(createURI(name), SOURCE) { override def getCharContent(ignoreEncodingErrors: Boolean): String = code } - /** Creates a compiled class with the given name. Class file will be placed in destDir. */ + /** Creates a compiled class with the source file. Class file will be placed in destDir. */ def createCompiledClass( className: String, destDir: File, - toStringValue: String = "", - baseClass: String = null, - classpathUrls: Seq[URL] = Seq()): File = { + sourceFile: JavaSourceFromString, + classpathUrls: Seq[URL]): File = { val compiler = ToolProvider.getSystemJavaCompiler - val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") - val sourceFile = new JavaSourceFromString(className, - "public class " + className + extendsText + " implements java.io.Serializable {" + - " @Override public String toString() { return \"" + toStringValue + "\"; }}") // Calling this outputs a class file in pwd. It's easier to just rename the file than // build a custom FileManager that controls the output location. @@ -144,4 +139,18 @@ private[spark] object TestUtils { assert(out.exists(), "Destination file not moved: " + out.getAbsolutePath()) out } + + /** Creates a compiled class with the given name. Class file will be placed in destDir. */ + def createCompiledClass( + className: String, + destDir: File, + toStringValue: String = "", + baseClass: String = null, + classpathUrls: Seq[URL] = Seq()): File = { + val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") + val sourceFile = new JavaSourceFromString(className, + "public class " + className + extendsText + " implements java.io.Serializable {" + + " @Override public String toString() { return \"" + toStringValue + "\"; }}") + createCompiledClass(className, destDir, sourceFile, classpathUrls) + } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 23b02e60338fb..a0c9b5e63c744 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -74,7 +74,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } else { None } - blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 + // Note: use getSizeAsKb (not bytes) to maintain compatiblity if no units are provided + blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024 } setConf(SparkEnv.get.conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala similarity index 59% rename from core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala rename to core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index b9798963bab0a..cd16f992a3c0a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -15,7 +15,9 @@ * limitations under the License. */ -package org.apache.spark.deploy.worker +package org.apache.spark.deploy + +import java.util.concurrent.CountDownLatch import org.apache.spark.{Logging, SparkConf, SecurityManager} import org.apache.spark.network.TransportContext @@ -23,6 +25,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.SaslRpcHandler import org.apache.spark.network.server.TransportServer import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler +import org.apache.spark.util.Utils /** * Provides a server from which Executors can read shuffle files (rather than reading directly from @@ -31,8 +34,8 @@ import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler * * Optionally requires SASL authentication in order to read. See [[SecurityManager]]. */ -private[worker] -class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) +private[deploy] +class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) extends Logging { private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) @@ -51,16 +54,58 @@ class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: Secu /** Starts the external shuffle service if the user has configured us to. */ def startIfEnabled() { if (enabled) { - require(server == null, "Shuffle server already started") - logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") - server = transportContext.createServer(port) + start() } } + /** Start the external shuffle service */ + def start() { + require(server == null, "Shuffle server already started") + logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") + server = transportContext.createServer(port) + } + def stop() { - if (enabled && server != null) { + if (server != null) { server.close() server = null } } } + +/** + * A main class for running the external shuffle service. + */ +object ExternalShuffleService extends Logging { + @volatile + private var server: ExternalShuffleService = _ + + private val barrier = new CountDownLatch(1) + + def main(args: Array[String]): Unit = { + val sparkConf = new SparkConf + Utils.loadDefaultSparkProperties(sparkConf) + val securityManager = new SecurityManager(sparkConf) + + // we override this value since this service is started from the command line + // and we assume the user really wants it to be running + sparkConf.set("spark.shuffle.service.enabled", "true") + server = new ExternalShuffleService(sparkConf, securityManager) + server.start() + + installShutdownHook() + + // keep running until the process is terminated + barrier.await() + } + + private def installShutdownHook(): Unit = { + Runtime.getRuntime.addShutdownHook(new Thread("External Shuffle Service shutdown thread") { + override def run() { + logInfo("Shutting down shuffle service.") + server.stop() + barrier.countDown() + } + }) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index a7c89276a045e..c048b78910f38 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -32,7 +32,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods import org.apache.spark.{Logging, SparkConf, SparkContext} -import org.apache.spark.deploy.master.{RecoveryState, SparkCuratorUtil} +import org.apache.spark.deploy.master.RecoveryState import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala similarity index 89% rename from core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala rename to core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala index 5b22481ea8c5f..b8d3993540220 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.deploy.master +package org.apache.spark.deploy import scala.collection.JavaConversions._ @@ -25,15 +25,17 @@ import org.apache.zookeeper.KeeperException import org.apache.spark.{Logging, SparkConf} -private[deploy] object SparkCuratorUtil extends Logging { +private[spark] object SparkCuratorUtil extends Logging { private val ZK_CONNECTION_TIMEOUT_MILLIS = 15000 private val ZK_SESSION_TIMEOUT_MILLIS = 60000 private val RETRY_WAIT_MILLIS = 5000 private val MAX_RECONNECT_ATTEMPTS = 3 - def newClient(conf: SparkConf): CuratorFramework = { - val ZK_URL = conf.get("spark.deploy.zookeeper.url") + def newClient( + conf: SparkConf, + zkUrlConf: String = "spark.deploy.zookeeper.url"): CuratorFramework = { + val ZK_URL = conf.get(zkUrlConf) val zk = CuratorFrameworkFactory.newClient(ZK_URL, ZK_SESSION_TIMEOUT_MILLIS, ZK_CONNECTION_TIMEOUT_MILLIS, new ExponentialBackoffRetry(RETRY_WAIT_MILLIS, MAX_RECONNECT_ATTEMPTS)) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 296a0764b8baf..0d149e703aff2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -20,6 +20,7 @@ package org.apache.spark.deploy import java.io.{File, PrintStream} import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} import java.net.URL +import java.nio.file.{Path => JavaPath} import java.security.PrivilegedExceptionAction import scala.collection.mutable.{ArrayBuffer, HashMap, Map} @@ -36,11 +37,11 @@ import org.apache.ivy.core.retrieve.RetrieveOptions import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver} - import org.apache.spark.SPARK_VERSION import org.apache.spark.deploy.rest._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} + /** * Whether to submit, kill, or request the status of an application. * The latter two operations are currently supported only for standalone cluster mode. @@ -114,18 +115,20 @@ object SparkSubmit { } } - /** Kill an existing submission using the REST protocol. Standalone cluster mode only. */ + /** + * Kill an existing submission using the REST protocol. Standalone and Mesos cluster mode only. + */ private def kill(args: SparkSubmitArguments): Unit = { - new StandaloneRestClient() + new RestSubmissionClient() .killSubmission(args.master, args.submissionToKill) } /** * Request the status of an existing submission using the REST protocol. - * Standalone cluster mode only. + * Standalone and Mesos cluster mode only. */ private def requestStatus(args: SparkSubmitArguments): Unit = { - new StandaloneRestClient() + new RestSubmissionClient() .requestSubmissionStatus(args.master, args.submissionToRequestStatusFor) } @@ -252,6 +255,7 @@ object SparkSubmit { } val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER + val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files // too for packages that include Python code @@ -294,8 +298,9 @@ object SparkSubmit { // The following modes are not supported or applicable (clusterManager, deployMode) match { - case (MESOS, CLUSTER) => - printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.") + case (MESOS, CLUSTER) if args.isPython => + printErrorAndExit("Cluster deploy mode is currently not supported for python " + + "applications on Mesos clusters.") case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") @@ -377,15 +382,6 @@ object SparkSubmit { OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.driver.extraLibraryPath"), - // Standalone cluster only - // Do not set CL arguments here because there are multiple possibilities for the main class - OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"), - OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"), - OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, sysProp = "spark.driver.memory"), - OptionAssigner(args.driverCores, STANDALONE, CLUSTER, sysProp = "spark.driver.cores"), - OptionAssigner(args.supervise.toString, STANDALONE, CLUSTER, - sysProp = "spark.driver.supervise"), - // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), @@ -413,7 +409,15 @@ object SparkSubmit { OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.cores.max"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, - sysProp = "spark.files") + sysProp = "spark.files"), + OptionAssigner(args.jars, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars"), + OptionAssigner(args.driverMemory, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.memory"), + OptionAssigner(args.driverCores, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.cores"), + OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.supervise"), + OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy") ) // In client mode, launch the application main class directly @@ -452,7 +456,7 @@ object SparkSubmit { // All Spark parameters are expected to be passed to the client through system properties. if (args.isStandaloneCluster) { if (args.useRest) { - childMainClass = "org.apache.spark.deploy.rest.StandaloneRestClient" + childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" childArgs += (args.primaryResource, args.mainClass) } else { // In legacy standalone cluster mode, use Client as a wrapper around the user class @@ -496,6 +500,15 @@ object SparkSubmit { } } + if (isMesosCluster) { + assert(args.useRest, "Mesos cluster mode is only supported through the REST submission API") + childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" + childArgs += (args.primaryResource, args.mainClass) + if (args.childArgs != null) { + childArgs ++= args.childArgs + } + } + // Load any properties specified through --conf and the default properties file for ((k, v) <- args.sparkProperties) { sysProps.getOrElseUpdate(k, v) @@ -696,7 +709,9 @@ private[deploy] object SparkSubmitUtils { * @param artifactId the artifactId of the coordinate * @param version the version of the coordinate */ - private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) + private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) { + override def toString: String = s"$groupId:$artifactId:$version" + } /** * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided @@ -719,16 +734,37 @@ private[deploy] object SparkSubmitUtils { } } + /** Path of the local Maven cache. */ + private[spark] def m2Path: JavaPath = new File(System.getProperty("user.home"), + ".m2" + File.separator + "repository" + File.separator).toPath + /** * Extracts maven coordinates from a comma-delimited string * @param remoteRepos Comma-delimited string of remote repositories + * @param ivySettings The Ivy settings for this session * @return A ChainResolver used by Ivy to search for and resolve dependencies. */ - def createRepoResolvers(remoteRepos: Option[String]): ChainResolver = { + def createRepoResolvers(remoteRepos: Option[String], ivySettings: IvySettings): ChainResolver = { // We need a chain resolver if we want to check multiple repositories val cr = new ChainResolver cr.setName("list") + val localM2 = new IBiblioResolver + localM2.setM2compatible(true) + localM2.setRoot(m2Path.toUri.toString) + localM2.setUsepoms(true) + localM2.setName("local-m2-cache") + cr.add(localM2) + + val localIvy = new IBiblioResolver + localIvy.setRoot(new File(ivySettings.getDefaultIvyUserDir, + "local" + File.separator).toURI.toString) + val ivyPattern = Seq("[organisation]", "[module]", "[revision]", "[type]s", + "[artifact](-[classifier]).[ext]").mkString(File.separator) + localIvy.setPattern(ivyPattern) + localIvy.setName("local-ivy-cache") + cr.add(localIvy) + // the biblio resolver resolves POM declared dependencies val br: IBiblioResolver = new IBiblioResolver br.setM2compatible(true) @@ -761,8 +797,7 @@ private[deploy] object SparkSubmitUtils { /** * Output a comma-delimited list of paths for the downloaded jars to be added to the classpath - * (will append to jars in SparkSubmit). The name of the jar is given - * after a '!' by Ivy. It also sometimes contains '(bundle)' after '.jar'. Remove that as well. + * (will append to jars in SparkSubmit). * @param artifacts Sequence of dependencies that were resolved and retrieved * @param cacheDirectory directory where jars are cached * @return a comma-delimited list of paths for the dependencies @@ -771,10 +806,9 @@ private[deploy] object SparkSubmitUtils { artifacts: Array[AnyRef], cacheDirectory: File): String = { artifacts.map { artifactInfo => - val artifactString = artifactInfo.toString - val jarName = artifactString.drop(artifactString.lastIndexOf("!") + 1) + val artifact = artifactInfo.asInstanceOf[Artifact].getModuleRevisionId cacheDirectory.getAbsolutePath + File.separator + - jarName.substring(0, jarName.lastIndexOf(".jar") + 4) + s"${artifact.getOrganisation}_${artifact.getName}-${artifact.getRevision}.jar" }.mkString(",") } @@ -842,67 +876,72 @@ private[deploy] object SparkSubmitUtils { "" } else { val sysOut = System.out - // To prevent ivy from logging to system out - System.setOut(printStream) - val artifacts = extractMavenCoordinates(coordinates) - // Default configuration name for ivy - val ivyConfName = "default" - // set ivy settings for location of cache - val ivySettings: IvySettings = new IvySettings - // Directories for caching downloads through ivy and storing the jars when maven coordinates - // are supplied to spark-submit - val alternateIvyCache = ivyPath.getOrElse("") - val packagesDirectory: File = - if (alternateIvyCache.trim.isEmpty) { - new File(ivySettings.getDefaultIvyUserDir, "jars") + try { + // To prevent ivy from logging to system out + System.setOut(printStream) + val artifacts = extractMavenCoordinates(coordinates) + // Default configuration name for ivy + val ivyConfName = "default" + // set ivy settings for location of cache + val ivySettings: IvySettings = new IvySettings + // Directories for caching downloads through ivy and storing the jars when maven coordinates + // are supplied to spark-submit + val alternateIvyCache = ivyPath.getOrElse("") + val packagesDirectory: File = + if (alternateIvyCache.trim.isEmpty) { + new File(ivySettings.getDefaultIvyUserDir, "jars") + } else { + ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) + ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) + new File(alternateIvyCache, "jars") + } + printStream.println( + s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") + printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // create a pattern matcher + ivySettings.addMatcher(new GlobPatternMatcher) + // create the dependency resolvers + val repoResolver = createRepoResolvers(remoteRepos, ivySettings) + ivySettings.addResolver(repoResolver) + ivySettings.setDefaultResolver(repoResolver.getName) + + val ivy = Ivy.newInstance(ivySettings) + // Set resolve options to download transitive dependencies as well + val resolveOptions = new ResolveOptions + resolveOptions.setTransitive(true) + val retrieveOptions = new RetrieveOptions + // Turn downloading and logging off for testing + if (isTest) { + resolveOptions.setDownload(false) + resolveOptions.setLog(LogOptions.LOG_QUIET) + retrieveOptions.setLog(LogOptions.LOG_QUIET) } else { - ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) - new File(alternateIvyCache, "jars") + resolveOptions.setDownload(true) } - printStream.println( - s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") - printStream.println(s"The jars for the packages stored in: $packagesDirectory") - // create a pattern matcher - ivySettings.addMatcher(new GlobPatternMatcher) - // create the dependency resolvers - val repoResolver = createRepoResolvers(remoteRepos) - ivySettings.addResolver(repoResolver) - ivySettings.setDefaultResolver(repoResolver.getName) - - val ivy = Ivy.newInstance(ivySettings) - // Set resolve options to download transitive dependencies as well - val resolveOptions = new ResolveOptions - resolveOptions.setTransitive(true) - val retrieveOptions = new RetrieveOptions - // Turn downloading and logging off for testing - if (isTest) { - resolveOptions.setDownload(false) - resolveOptions.setLog(LogOptions.LOG_QUIET) - retrieveOptions.setLog(LogOptions.LOG_QUIET) - } else { - resolveOptions.setDownload(true) - } - // A Module descriptor must be specified. Entries are dummy strings - val md = getModuleDescriptor - md.setDefaultConf(ivyConfName) + // A Module descriptor must be specified. Entries are dummy strings + val md = getModuleDescriptor + md.setDefaultConf(ivyConfName) - // Add exclusion rules for Spark and Scala Library - addExclusionRules(ivySettings, ivyConfName, md) - // add all supplied maven artifacts as dependencies - addDependenciesToIvy(md, artifacts, ivyConfName) + // Add exclusion rules for Spark and Scala Library + addExclusionRules(ivySettings, ivyConfName, md) + // add all supplied maven artifacts as dependencies + addDependenciesToIvy(md, artifacts, ivyConfName) - // resolve dependencies - val rr: ResolveReport = ivy.resolve(md, resolveOptions) - if (rr.hasError) { - throw new RuntimeException(rr.getAllProblemMessages.toString) + // resolve dependencies + val rr: ResolveReport = ivy.resolve(md, resolveOptions) + if (rr.hasError) { + throw new RuntimeException(rr.getAllProblemMessages.toString) + } + // retrieve all resolved dependencies + ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, + packagesDirectory.getAbsolutePath + File.separator + + "[organization]_[artifact]-[revision].[ext]", + retrieveOptions.setConfs(Array(ivyConfName))) + resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + } finally { + System.setOut(sysOut) } - // retrieve all resolved dependencies - ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, - packagesDirectory.getAbsolutePath + File.separator + "[artifact](-[classifier]).[ext]", - retrieveOptions.setConfs(Array(ivyConfName))) - System.setOut(sysOut) - resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index c896842943f2b..c621b8fc86f94 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -241,8 +241,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def validateKillArguments(): Unit = { - if (!master.startsWith("spark://")) { - SparkSubmit.printErrorAndExit("Killing submissions is only supported in standalone mode!") + if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { + SparkSubmit.printErrorAndExit( + "Killing submissions is only supported in standalone or Mesos mode!") } if (submissionToKill == null) { SparkSubmit.printErrorAndExit("Please specify a submission to kill.") @@ -250,9 +251,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def validateStatusRequestArguments(): Unit = { - if (!master.startsWith("spark://")) { + if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { SparkSubmit.printErrorAndExit( - "Requesting submission statuses is only supported in standalone mode!") + "Requesting submission statuses is only supported in standalone or Mesos mode!") } if (submissionToRequestStatusFor == null) { SparkSubmit.printErrorAndExit("Please specify a submission to request status for.") @@ -485,6 +486,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | | Spark standalone with cluster deploy mode only: | --driver-cores NUM Cores for driver (Default: 1). + | + | Spark standalone or Mesos with cluster deploy mode only: | --supervise If given, restarts the driver on failure. | --kill SUBMISSION_ID If given, kills the driver specified. | --status SUBMISSION_ID If given, requests the status of the driver specified. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a94ebf6e53750..fb2cbbcccc54b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -333,8 +333,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } try { val appListener = new ApplicationEventListener + val appCompleted = isApplicationCompleted(eventLog) bus.addListener(appListener) - bus.replay(logInput, logPath.toString) + bus.replay(logInput, logPath.toString, !appCompleted) new FsApplicationHistoryInfo( logPath.getName(), appListener.appId.getOrElse(logPath.getName()), @@ -343,7 +344,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis appListener.endTime.getOrElse(-1L), getModificationTime(eventLog).get, appListener.sparkUser.getOrElse(NOT_STARTED), - isApplicationCompleted(eventLog)) + appCompleted) } finally { logInput.close() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index ff2eed6dee70a..1c21c179562ac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -130,7 +130,7 @@ private[master] class Master( private val restServer = if (restServerEnabled) { val port = conf.getInt("spark.master.rest.port", 6066) - Some(new StandaloneRestServer(host, port, self, masterUrl, conf)) + Some(new StandaloneRestServer(host, port, conf, self, masterUrl)) } else { None } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 4823fd7cac0cb..52758d6a7c4be 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -23,6 +23,7 @@ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} +import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable, conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index a285783f72000..80db6d474b5c1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -26,6 +26,7 @@ import org.apache.curator.framework.CuratorFramework import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala new file mode 100644 index 0000000000000..5d4e5b899dfdc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import java.util.concurrent.CountDownLatch + +import org.apache.spark.deploy.mesos.ui.MesosClusterUI +import org.apache.spark.deploy.rest.mesos.MesosRestServer +import org.apache.spark.scheduler.cluster.mesos._ +import org.apache.spark.util.SignalLogger +import org.apache.spark.{Logging, SecurityManager, SparkConf} + +/* + * A dispatcher that is responsible for managing and launching drivers, and is intended to be + * used for Mesos cluster mode. The dispatcher is a long-running process started by the user in + * the cluster independently of Spark applications. + * It contains a [[MesosRestServer]] that listens for requests to submit drivers and a + * [[MesosClusterScheduler]] that processes these requests by negotiating with the Mesos master + * for resources. + * + * A typical new driver lifecycle is the following: + * - Driver submitted via spark-submit talking to the [[MesosRestServer]] + * - [[MesosRestServer]] queues the driver request to [[MesosClusterScheduler]] + * - [[MesosClusterScheduler]] gets resource offers and launches the drivers that are in queue + * + * This dispatcher supports both Mesos fine-grain or coarse-grain mode as the mode is configurable + * per driver launched. + * This class is needed since Mesos doesn't manage frameworks, so the dispatcher acts as + * a daemon to launch drivers as Mesos frameworks upon request. The dispatcher is also started and + * stopped by sbin/start-mesos-dispatcher and sbin/stop-mesos-dispatcher respectively. + */ +private[mesos] class MesosClusterDispatcher( + args: MesosClusterDispatcherArguments, + conf: SparkConf) + extends Logging { + + private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host) + private val recoveryMode = conf.get("spark.mesos.deploy.recoveryMode", "NONE").toUpperCase() + logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode) + + private val engineFactory = recoveryMode match { + case "NONE" => new BlackHoleMesosClusterPersistenceEngineFactory + case "ZOOKEEPER" => new ZookeeperMesosClusterPersistenceEngineFactory(conf) + case _ => throw new IllegalArgumentException("Unsupported recovery mode: " + recoveryMode) + } + + private val scheduler = new MesosClusterScheduler(engineFactory, conf) + + private val server = new MesosRestServer(args.host, args.port, conf, scheduler) + private val webUi = new MesosClusterUI( + new SecurityManager(conf), + args.webUiPort, + conf, + publicAddress, + scheduler) + + private val shutdownLatch = new CountDownLatch(1) + + def start(): Unit = { + webUi.bind() + scheduler.frameworkUrl = webUi.activeWebUiUrl + scheduler.start() + server.start() + } + + def awaitShutdown(): Unit = { + shutdownLatch.await() + } + + def stop(): Unit = { + webUi.stop() + server.stop() + scheduler.stop() + shutdownLatch.countDown() + } +} + +private[mesos] object MesosClusterDispatcher extends Logging { + def main(args: Array[String]) { + SignalLogger.register(log) + val conf = new SparkConf + val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) + conf.setMaster(dispatcherArgs.masterUrl) + conf.setAppName(dispatcherArgs.name) + dispatcherArgs.zookeeperUrl.foreach { z => + conf.set("spark.mesos.deploy.recoveryMode", "ZOOKEEPER") + conf.set("spark.mesos.deploy.zookeeper.url", z) + } + val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf) + dispatcher.start() + val shutdownHook = new Thread() { + override def run() { + logInfo("Shutdown hook is shutting down dispatcher") + dispatcher.stop() + dispatcher.awaitShutdown() + } + } + Runtime.getRuntime.addShutdownHook(shutdownHook) + dispatcher.awaitShutdown() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala new file mode 100644 index 0000000000000..894cb78d8591a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import org.apache.spark.SparkConf +import org.apache.spark.util.{IntParam, Utils} + + +private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) { + var host = Utils.localHostName() + var port = 7077 + var name = "Spark Cluster" + var webUiPort = 8081 + var masterUrl: String = _ + var zookeeperUrl: Option[String] = None + var propertiesFile: String = _ + + parse(args.toList) + + propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + + private def parse(args: List[String]): Unit = args match { + case ("--host" | "-h") :: value :: tail => + Utils.checkHost(value, "Please use hostname " + value) + host = value + parse(tail) + + case ("--port" | "-p") :: IntParam(value) :: tail => + port = value + parse(tail) + + case ("--webui-port" | "-p") :: IntParam(value) :: tail => + webUiPort = value + parse(tail) + + case ("--zk" | "-z") :: value :: tail => + zookeeperUrl = Some(value) + parse(tail) + + case ("--master" | "-m") :: value :: tail => + if (!value.startsWith("mesos://")) { + System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)") + System.exit(1) + } + masterUrl = value.stripPrefix("mesos://") + parse(tail) + + case ("--name") :: value :: tail => + name = value + parse(tail) + + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + + case ("--help") :: tail => + printUsageAndExit(0) + + case Nil => { + if (masterUrl == null) { + System.err.println("--master is required") + printUsageAndExit(1) + } + } + + case _ => + printUsageAndExit(1) + } + + private def printUsageAndExit(exitCode: Int): Unit = { + System.err.println( + "Usage: MesosClusterDispatcher [options]\n" + + "\n" + + "Options:\n" + + " -h HOST, --host HOST Hostname to listen on\n" + + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + + " --webui-port WEBUI_PORT WebUI Port to listen on (default: 8081)\n" + + " --name NAME Framework name to show in Mesos UI\n" + + " -m --master MASTER URI for connecting to Mesos master\n" + + " -z --zk ZOOKEEPER Comma delimited URLs for connecting to \n" + + " Zookeeper for persistence\n" + + " --properties-file FILE Path to a custom Spark properties file.\n" + + " Default is conf/spark-defaults.conf.") + System.exit(exitCode) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala new file mode 100644 index 0000000000000..1948226800afe --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import java.util.Date + +import org.apache.spark.deploy.Command +import org.apache.spark.scheduler.cluster.mesos.MesosClusterRetryState + +/** + * Describes a Spark driver that is submitted from the + * [[org.apache.spark.deploy.rest.mesos.MesosRestServer]], to be launched by + * [[org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler]]. + * @param jarUrl URL to the application jar + * @param mem Amount of memory for the driver + * @param cores Number of cores for the driver + * @param supervise Supervise the driver for long running app + * @param command The command to launch the driver. + * @param schedulerProperties Extra properties to pass the Mesos scheduler + */ +private[spark] class MesosDriverDescription( + val name: String, + val jarUrl: String, + val mem: Int, + val cores: Double, + val supervise: Boolean, + val command: Command, + val schedulerProperties: Map[String, String], + val submissionId: String, + val submissionDate: Date, + val retryState: Option[MesosClusterRetryState] = None) + extends Serializable { + + def copy( + name: String = name, + jarUrl: String = jarUrl, + mem: Int = mem, + cores: Double = cores, + supervise: Boolean = supervise, + command: Command = command, + schedulerProperties: Map[String, String] = schedulerProperties, + submissionId: String = submissionId, + submissionDate: Date = submissionDate, + retryState: Option[MesosClusterRetryState] = retryState): MesosDriverDescription = { + new MesosDriverDescription(name, jarUrl, mem, cores, supervise, command, schedulerProperties, + submissionId, submissionDate, retryState) + } + + override def toString: String = s"MesosDriverDescription (${command.mainClass})" +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala new file mode 100644 index 0000000000000..7b2005e0f1237 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.mesos.Protos.TaskStatus +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.scheduler.cluster.mesos.MesosClusterSubmissionState +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage("") { + def render(request: HttpServletRequest): Seq[Node] = { + val state = parent.scheduler.getSchedulerState() + val queuedHeaders = Seq("Driver ID", "Submit Date", "Main Class", "Driver Resources") + val driverHeaders = queuedHeaders ++ + Seq("Start Date", "Mesos Slave ID", "State") + val retryHeaders = Seq("Driver ID", "Submit Date", "Description") ++ + Seq("Last Failed Status", "Next Retry Time", "Attempt Count") + val queuedTable = UIUtils.listingTable(queuedHeaders, queuedRow, state.queuedDrivers) + val launchedTable = UIUtils.listingTable(driverHeaders, driverRow, state.launchedDrivers) + val finishedTable = UIUtils.listingTable(driverHeaders, driverRow, state.finishedDrivers) + val retryTable = UIUtils.listingTable(retryHeaders, retryRow, state.pendingRetryDrivers) + val content = +

Mesos Framework ID: {state.frameworkId}

+
+
+

Queued Drivers:

+ {queuedTable} +

Launched Drivers:

+ {launchedTable} +

Finished Drivers:

+ {finishedTable} +

Supervise drivers waiting for retry:

+ {retryTable} +
+
; + UIUtils.basicSparkPage(content, "Spark Drivers for Mesos cluster") + } + + private def queuedRow(submission: MesosDriverDescription): Seq[Node] = { + + {submission.submissionId} + {submission.submissionDate} + {submission.command.mainClass} + cpus: {submission.cores}, mem: {submission.mem} + + } + + private def driverRow(state: MesosClusterSubmissionState): Seq[Node] = { + + {state.driverDescription.submissionId} + {state.driverDescription.submissionDate} + {state.driverDescription.command.mainClass} + cpus: {state.driverDescription.cores}, mem: {state.driverDescription.mem} + {state.startDate} + {state.slaveId.getValue} + {stateString(state.mesosTaskStatus)} + + } + + private def retryRow(submission: MesosDriverDescription): Seq[Node] = { + + {submission.submissionId} + {submission.submissionDate} + {submission.command.mainClass} + {submission.retryState.get.lastFailureStatus} + {submission.retryState.get.nextRetry} + {submission.retryState.get.retries} + + } + + private def stateString(status: Option[TaskStatus]): String = { + if (status.isEmpty) { + return "" + } + val sb = new StringBuilder + val s = status.get + sb.append(s"State: ${s.getState}") + if (status.get.hasMessage) { + sb.append(s", Message: ${s.getMessage}") + } + if (status.get.hasHealthy) { + sb.append(s", Healthy: ${s.getHealthy}") + } + if (status.get.hasSource) { + sb.append(s", Source: ${s.getSource}") + } + if (status.get.hasReason) { + sb.append(s", Reason: ${s.getReason}") + } + if (status.get.hasTimestamp) { + sb.append(s", Time: ${s.getTimestamp}") + } + sb.toString() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala new file mode 100644 index 0000000000000..4865d46dbc4ab --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos.ui + +import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.ui.{SparkUI, WebUI} + +/** + * UI that displays driver results from the [[org.apache.spark.deploy.mesos.MesosClusterDispatcher]] + */ +private[spark] class MesosClusterUI( + securityManager: SecurityManager, + port: Int, + conf: SparkConf, + dispatcherPublicAddress: String, + val scheduler: MesosClusterScheduler) + extends WebUI(securityManager, port, conf) { + + initialize() + + def activeWebUiUrl: String = "http://" + dispatcherPublicAddress + ":" + boundPort + + override def initialize() { + attachPage(new MesosClusterPage(this)) + attachHandler(createStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR, "/static")) + } +} + +private object MesosClusterUI { + val STATIC_RESOURCE_DIR = SparkUI.STATIC_RESOURCE_DIR +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala similarity index 93% rename from core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala rename to core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index b8fd406fb6f9a..307cebfb4bd09 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -30,9 +30,7 @@ import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} import org.apache.spark.util.Utils /** - * A client that submits applications to the standalone Master using a REST protocol. - * This client is intended to communicate with the [[StandaloneRestServer]] and is - * currently used for cluster mode only. + * A client that submits applications to a [[RestSubmissionServer]]. * * In protocol version v1, the REST URL takes the form http://[host:port]/v1/submissions/[action], * where [action] can be one of create, kill, or status. Each type of request is represented in @@ -53,8 +51,10 @@ import org.apache.spark.util.Utils * implementation of this client can use that information to retry using the version specified * by the server. */ -private[deploy] class StandaloneRestClient extends Logging { - import StandaloneRestClient._ +private[spark] class RestSubmissionClient extends Logging { + import RestSubmissionClient._ + + private val supportedMasterPrefixes = Seq("spark://", "mesos://") /** * Submit an application specified by the parameters in the provided request. @@ -62,7 +62,7 @@ private[deploy] class StandaloneRestClient extends Logging { * If the submission was successful, poll the status of the submission and report * it to the user. Otherwise, report the error message provided by the server. */ - private[rest] def createSubmission( + def createSubmission( master: String, request: CreateSubmissionRequest): SubmitRestProtocolResponse = { logInfo(s"Submitting a request to launch an application in $master.") @@ -107,7 +107,7 @@ private[deploy] class StandaloneRestClient extends Logging { } /** Construct a message that captures the specified parameters for submitting an application. */ - private[rest] def constructSubmitRequest( + def constructSubmitRequest( appResource: String, mainClass: String, appArgs: Array[String], @@ -219,14 +219,23 @@ private[deploy] class StandaloneRestClient extends Logging { /** Return the base URL for communicating with the server, including the protocol version. */ private def getBaseUrl(master: String): String = { - val masterUrl = master.stripPrefix("spark://").stripSuffix("/") + var masterUrl = master + supportedMasterPrefixes.foreach { prefix => + if (master.startsWith(prefix)) { + masterUrl = master.stripPrefix(prefix) + } + } + masterUrl = masterUrl.stripSuffix("/") s"http://$masterUrl/$PROTOCOL_VERSION/submissions" } /** Throw an exception if this is not standalone mode. */ private def validateMaster(master: String): Unit = { - if (!master.startsWith("spark://")) { - throw new IllegalArgumentException("This REST client is only supported in standalone mode.") + val valid = supportedMasterPrefixes.exists { prefix => master.startsWith(prefix) } + if (!valid) { + throw new IllegalArgumentException( + "This REST client only supports master URLs that start with " + + "one of the following: " + supportedMasterPrefixes.mkString(",")) } } @@ -295,7 +304,7 @@ private[deploy] class StandaloneRestClient extends Logging { } } -private[rest] object StandaloneRestClient { +private[spark] object RestSubmissionClient { private val REPORT_DRIVER_STATUS_INTERVAL = 1000 private val REPORT_DRIVER_STATUS_MAX_TRIES = 10 val PROTOCOL_VERSION = "v1" @@ -315,7 +324,7 @@ private[rest] object StandaloneRestClient { } val sparkProperties = conf.getAll.toMap val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") } - val client = new StandaloneRestClient + val client = new RestSubmissionClient val submitRequest = client.constructSubmitRequest( appResource, mainClass, appArgs, sparkProperties, environmentVariables) client.createSubmission(master, submitRequest) @@ -323,7 +332,7 @@ private[rest] object StandaloneRestClient { def main(args: Array[String]): Unit = { if (args.size < 2) { - sys.error("Usage: StandaloneRestClient [app resource] [main class] [app args*]") + sys.error("Usage: RestSubmissionClient [app resource] [main class] [app args*]") sys.exit(1) } val appResource = args(0) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala new file mode 100644 index 0000000000000..2e78d03e5c0cc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.net.InetSocketAddress +import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} + +import scala.io.Source +import com.fasterxml.jackson.core.JsonProcessingException +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} +import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.util.Utils + +/** + * A server that responds to requests submitted by the [[RestSubmissionClient]]. + * + * This server responds with different HTTP codes depending on the situation: + * 200 OK - Request was processed successfully + * 400 BAD REQUEST - Request was malformed, not successfully validated, or of unexpected type + * 468 UNKNOWN PROTOCOL VERSION - Request specified a protocol this server does not understand + * 500 INTERNAL SERVER ERROR - Server throws an exception internally while processing the request + * + * The server always includes a JSON representation of the relevant [[SubmitRestProtocolResponse]] + * in the HTTP body. If an error occurs, however, the server will include an [[ErrorResponse]] + * instead of the one expected by the client. If the construction of this error response itself + * fails, the response will consist of an empty body with a response code that indicates internal + * server error. + */ +private[spark] abstract class RestSubmissionServer( + val host: String, + val requestedPort: Int, + val masterConf: SparkConf) extends Logging { + protected val submitRequestServlet: SubmitRequestServlet + protected val killRequestServlet: KillRequestServlet + protected val statusRequestServlet: StatusRequestServlet + + private var _server: Option[Server] = None + + // A mapping from URL prefixes to servlets that serve them. Exposed for testing. + protected val baseContext = s"/${RestSubmissionServer.PROTOCOL_VERSION}/submissions" + protected lazy val contextToServlet = Map[String, RestServlet]( + s"$baseContext/create/*" -> submitRequestServlet, + s"$baseContext/kill/*" -> killRequestServlet, + s"$baseContext/status/*" -> statusRequestServlet, + "/*" -> new ErrorServlet // default handler + ) + + /** Start the server and return the bound port. */ + def start(): Int = { + val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, masterConf) + _server = Some(server) + logInfo(s"Started REST server for submitting applications on port $boundPort") + boundPort + } + + /** + * Map the servlets to their corresponding contexts and attach them to a server. + * Return a 2-tuple of the started server and the bound port. + */ + private def doStart(startPort: Int): (Server, Int) = { + val server = new Server(new InetSocketAddress(host, startPort)) + val threadPool = new QueuedThreadPool + threadPool.setDaemon(true) + server.setThreadPool(threadPool) + val mainHandler = new ServletContextHandler + mainHandler.setContextPath("/") + contextToServlet.foreach { case (prefix, servlet) => + mainHandler.addServlet(new ServletHolder(servlet), prefix) + } + server.setHandler(mainHandler) + server.start() + val boundPort = server.getConnectors()(0).getLocalPort + (server, boundPort) + } + + def stop(): Unit = { + _server.foreach(_.stop()) + } +} + +private[rest] object RestSubmissionServer { + val PROTOCOL_VERSION = RestSubmissionClient.PROTOCOL_VERSION + val SC_UNKNOWN_PROTOCOL_VERSION = 468 +} + +/** + * An abstract servlet for handling requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class RestServlet extends HttpServlet with Logging { + + /** + * Serialize the given response message to JSON and send it through the response servlet. + * This validates the response before sending it to ensure it is properly constructed. + */ + protected def sendResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): Unit = { + val message = validateResponse(responseMessage, responseServlet) + responseServlet.setContentType("application/json") + responseServlet.setCharacterEncoding("utf-8") + responseServlet.getWriter.write(message.toJson) + } + + /** + * Return any fields in the client request message that the server does not know about. + * + * The mechanism for this is to reconstruct the JSON on the server side and compare the + * diff between this JSON and the one generated on the client side. Any fields that are + * only in the client JSON are treated as unexpected. + */ + protected def findUnknownFields( + requestJson: String, + requestMessage: SubmitRestProtocolMessage): Array[String] = { + val clientSideJson = parse(requestJson) + val serverSideJson = parse(requestMessage.toJson) + val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson) + unknown match { + case j: JObject => j.obj.map { case (k, _) => k }.toArray + case _ => Array.empty[String] // No difference + } + } + + /** Return a human readable String representation of the exception. */ + protected def formatException(e: Throwable): String = { + val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") + s"$e\n$stackTraceString" + } + + /** Construct an error message to signal the fact that an exception has been thrown. */ + protected def handleError(message: String): ErrorResponse = { + val e = new ErrorResponse + e.serverSparkVersion = sparkVersion + e.message = message + e + } + + /** + * Parse a submission ID from the relative path, assuming it is the first part of the path. + * For instance, we expect the path to take the form /[submission ID]/maybe/something/else. + * The returned submission ID cannot be empty. If the path is unexpected, return None. + */ + protected def parseSubmissionId(path: String): Option[String] = { + if (path == null || path.isEmpty) { + None + } else { + path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty) + } + } + + /** + * Validate the response to ensure that it is correctly constructed. + * + * If it is, simply return the message as is. Otherwise, return an error response instead + * to propagate the exception back to the client and set the appropriate error code. + */ + private def validateResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + try { + responseMessage.validate() + responseMessage + } catch { + case e: Exception => + responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + handleError("Internal server error: " + formatException(e)) + } + } +} + +/** + * A servlet for handling kill requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class KillRequestServlet extends RestServlet { + + /** + * If a submission ID is specified in the URL, have the Master kill the corresponding + * driver and return an appropriate response to the client. Otherwise, return error. + */ + protected override def doPost( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleKill).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in kill request.") + } + sendResponse(responseMessage, response) + } + + protected def handleKill(submissionId: String): KillSubmissionResponse +} + +/** + * A servlet for handling status requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class StatusRequestServlet extends RestServlet { + + /** + * If a submission ID is specified in the URL, request the status of the corresponding + * driver from the Master and include it in the response. Otherwise, return error. + */ + protected override def doGet( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleStatus).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in status request.") + } + sendResponse(responseMessage, response) + } + + protected def handleStatus(submissionId: String): SubmissionStatusResponse +} + +/** + * A servlet for handling submit requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class SubmitRequestServlet extends RestServlet { + + /** + * Submit an application to the Master with parameters specified in the request. + * + * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON. + * If the request is successfully processed, return an appropriate response to the + * client indicating so. Otherwise, return error instead. + */ + protected override def doPost( + requestServlet: HttpServletRequest, + responseServlet: HttpServletResponse): Unit = { + val responseMessage = + try { + val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString + val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) + // The response should have already been validated on the client. + // In case this is not true, validate it ourselves to avoid potential NPEs. + requestMessage.validate() + handleSubmit(requestMessageJson, requestMessage, responseServlet) + } catch { + // The client failed to provide a valid JSON, so this is not our fault + case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Malformed request: " + formatException(e)) + } + sendResponse(responseMessage, responseServlet) + } + + protected def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse +} + +/** + * A default servlet that handles error cases that are not captured by other servlets. + */ +private class ErrorServlet extends RestServlet { + private val serverVersion = RestSubmissionServer.PROTOCOL_VERSION + + /** Service a faulty request by returning an appropriate error message to the client. */ + protected override def service( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val path = request.getPathInfo + val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList + var versionMismatch = false + var msg = + parts match { + case Nil => + // http://host:port/ + "Missing protocol version." + case `serverVersion` :: Nil => + // http://host:port/correct-version + "Missing the /submissions prefix." + case `serverVersion` :: "submissions" :: tail => + // http://host:port/correct-version/submissions/* + "Missing an action: please specify one of /create, /kill, or /status." + case unknownVersion :: tail => + // http://host:port/unknown-version/* + versionMismatch = true + s"Unknown protocol version '$unknownVersion'." + case _ => + // never reached + s"Malformed path $path." + } + msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..." + val error = handleError(msg) + // If there is a version mismatch, include the highest protocol version that + // this server supports in case the client wants to retry with our version + if (versionMismatch) { + error.highestProtocolVersion = serverVersion + response.setStatus(RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) + } else { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + } + sendResponse(error, response) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 2d6b8d4204795..502b9bb701ccf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -18,26 +18,16 @@ package org.apache.spark.deploy.rest import java.io.File -import java.net.InetSocketAddress -import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} - -import scala.io.Source +import javax.servlet.http.HttpServletResponse import akka.actor.ActorRef -import com.fasterxml.jackson.core.JsonProcessingException -import org.eclipse.jetty.server.Server -import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} -import org.eclipse.jetty.util.thread.QueuedThreadPool -import org.json4s._ -import org.json4s.jackson.JsonMethods._ - -import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} -import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} -import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} import org.apache.spark.deploy.ClientArguments._ +import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} +import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} /** - * A server that responds to requests submitted by the [[StandaloneRestClient]]. + * A server that responds to requests submitted by the [[RestSubmissionClient]]. * This is intended to be embedded in the standalone Master and used in cluster mode only. * * This server responds with different HTTP codes depending on the situation: @@ -54,173 +44,31 @@ import org.apache.spark.deploy.ClientArguments._ * * @param host the address this server should bind to * @param requestedPort the port this server will attempt to bind to + * @param masterConf the conf used by the Master * @param masterActor reference to the Master actor to which requests can be sent * @param masterUrl the URL of the Master new drivers will attempt to connect to - * @param masterConf the conf used by the Master */ private[deploy] class StandaloneRestServer( host: String, requestedPort: Int, + masterConf: SparkConf, masterActor: ActorRef, - masterUrl: String, - masterConf: SparkConf) - extends Logging { - - import StandaloneRestServer._ - - private var _server: Option[Server] = None - - // A mapping from URL prefixes to servlets that serve them. Exposed for testing. - protected val baseContext = s"/$PROTOCOL_VERSION/submissions" - protected val contextToServlet = Map[String, StandaloneRestServlet]( - s"$baseContext/create/*" -> new SubmitRequestServlet(masterActor, masterUrl, masterConf), - s"$baseContext/kill/*" -> new KillRequestServlet(masterActor, masterConf), - s"$baseContext/status/*" -> new StatusRequestServlet(masterActor, masterConf), - "/*" -> new ErrorServlet // default handler - ) - - /** Start the server and return the bound port. */ - def start(): Int = { - val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, masterConf) - _server = Some(server) - logInfo(s"Started REST server for submitting applications on port $boundPort") - boundPort - } - - /** - * Map the servlets to their corresponding contexts and attach them to a server. - * Return a 2-tuple of the started server and the bound port. - */ - private def doStart(startPort: Int): (Server, Int) = { - val server = new Server(new InetSocketAddress(host, startPort)) - val threadPool = new QueuedThreadPool - threadPool.setDaemon(true) - server.setThreadPool(threadPool) - val mainHandler = new ServletContextHandler - mainHandler.setContextPath("/") - contextToServlet.foreach { case (prefix, servlet) => - mainHandler.addServlet(new ServletHolder(servlet), prefix) - } - server.setHandler(mainHandler) - server.start() - val boundPort = server.getConnectors()(0).getLocalPort - (server, boundPort) - } - - def stop(): Unit = { - _server.foreach(_.stop()) - } -} - -private[rest] object StandaloneRestServer { - val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION - val SC_UNKNOWN_PROTOCOL_VERSION = 468 -} - -/** - * An abstract servlet for handling requests passed to the [[StandaloneRestServer]]. - */ -private[rest] abstract class StandaloneRestServlet extends HttpServlet with Logging { - - /** - * Serialize the given response message to JSON and send it through the response servlet. - * This validates the response before sending it to ensure it is properly constructed. - */ - protected def sendResponse( - responseMessage: SubmitRestProtocolResponse, - responseServlet: HttpServletResponse): Unit = { - val message = validateResponse(responseMessage, responseServlet) - responseServlet.setContentType("application/json") - responseServlet.setCharacterEncoding("utf-8") - responseServlet.getWriter.write(message.toJson) - } - - /** - * Return any fields in the client request message that the server does not know about. - * - * The mechanism for this is to reconstruct the JSON on the server side and compare the - * diff between this JSON and the one generated on the client side. Any fields that are - * only in the client JSON are treated as unexpected. - */ - protected def findUnknownFields( - requestJson: String, - requestMessage: SubmitRestProtocolMessage): Array[String] = { - val clientSideJson = parse(requestJson) - val serverSideJson = parse(requestMessage.toJson) - val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson) - unknown match { - case j: JObject => j.obj.map { case (k, _) => k }.toArray - case _ => Array.empty[String] // No difference - } - } - - /** Return a human readable String representation of the exception. */ - protected def formatException(e: Throwable): String = { - val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") - s"$e\n$stackTraceString" - } - - /** Construct an error message to signal the fact that an exception has been thrown. */ - protected def handleError(message: String): ErrorResponse = { - val e = new ErrorResponse - e.serverSparkVersion = sparkVersion - e.message = message - e - } - - /** - * Parse a submission ID from the relative path, assuming it is the first part of the path. - * For instance, we expect the path to take the form /[submission ID]/maybe/something/else. - * The returned submission ID cannot be empty. If the path is unexpected, return None. - */ - protected def parseSubmissionId(path: String): Option[String] = { - if (path == null || path.isEmpty) { - None - } else { - path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty) - } - } - - /** - * Validate the response to ensure that it is correctly constructed. - * - * If it is, simply return the message as is. Otherwise, return an error response instead - * to propagate the exception back to the client and set the appropriate error code. - */ - private def validateResponse( - responseMessage: SubmitRestProtocolResponse, - responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { - try { - responseMessage.validate() - responseMessage - } catch { - case e: Exception => - responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR) - handleError("Internal server error: " + formatException(e)) - } - } + masterUrl: String) + extends RestSubmissionServer(host, requestedPort, masterConf) { + + protected override val submitRequestServlet = + new StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) + protected override val killRequestServlet = + new StandaloneKillRequestServlet(masterActor, masterConf) + protected override val statusRequestServlet = + new StandaloneStatusRequestServlet(masterActor, masterConf) } /** * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. */ -private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) - extends StandaloneRestServlet { - - /** - * If a submission ID is specified in the URL, have the Master kill the corresponding - * driver and return an appropriate response to the client. Otherwise, return error. - */ - protected override def doPost( - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - val submissionId = parseSubmissionId(request.getPathInfo) - val responseMessage = submissionId.map(handleKill).getOrElse { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Submission ID is missing in kill request.") - } - sendResponse(responseMessage, response) - } +private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: SparkConf) + extends KillRequestServlet { protected def handleKill(submissionId: String): KillSubmissionResponse = { val askTimeout = RpcUtils.askTimeout(conf) @@ -238,23 +86,8 @@ private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) /** * A servlet for handling status requests passed to the [[StandaloneRestServer]]. */ -private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) - extends StandaloneRestServlet { - - /** - * If a submission ID is specified in the URL, request the status of the corresponding - * driver from the Master and include it in the response. Otherwise, return error. - */ - protected override def doGet( - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - val submissionId = parseSubmissionId(request.getPathInfo) - val responseMessage = submissionId.map(handleStatus).getOrElse { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Submission ID is missing in status request.") - } - sendResponse(responseMessage, response) - } +private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: SparkConf) + extends StatusRequestServlet { protected def handleStatus(submissionId: String): SubmissionStatusResponse = { val askTimeout = RpcUtils.askTimeout(conf) @@ -276,71 +109,11 @@ private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) /** * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. */ -private[rest] class SubmitRequestServlet( +private[rest] class StandaloneSubmitRequestServlet( masterActor: ActorRef, masterUrl: String, conf: SparkConf) - extends StandaloneRestServlet { - - /** - * Submit an application to the Master with parameters specified in the request. - * - * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON. - * If the request is successfully processed, return an appropriate response to the - * client indicating so. Otherwise, return error instead. - */ - protected override def doPost( - requestServlet: HttpServletRequest, - responseServlet: HttpServletResponse): Unit = { - val responseMessage = - try { - val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString - val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) - // The response should have already been validated on the client. - // In case this is not true, validate it ourselves to avoid potential NPEs. - requestMessage.validate() - handleSubmit(requestMessageJson, requestMessage, responseServlet) - } catch { - // The client failed to provide a valid JSON, so this is not our fault - case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) => - responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Malformed request: " + formatException(e)) - } - sendResponse(responseMessage, responseServlet) - } - - /** - * Handle the submit request and construct an appropriate response to return to the client. - * - * This assumes that the request message is already successfully validated. - * If the request message is not of the expected type, return error to the client. - */ - private def handleSubmit( - requestMessageJson: String, - requestMessage: SubmitRestProtocolMessage, - responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { - requestMessage match { - case submitRequest: CreateSubmissionRequest => - val askTimeout = RpcUtils.askTimeout(conf) - val driverDescription = buildDriverDescription(submitRequest) - val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( - DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) - val submitResponse = new CreateSubmissionResponse - submitResponse.serverSparkVersion = sparkVersion - submitResponse.message = response.message - submitResponse.success = response.success - submitResponse.submissionId = response.driverId.orNull - val unknownFields = findUnknownFields(requestMessageJson, requestMessage) - if (unknownFields.nonEmpty) { - // If there are fields that the server does not know about, warn the client - submitResponse.unknownFields = unknownFields - } - submitResponse - case unexpected => - responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError(s"Received message of unexpected type ${unexpected.messageType}.") - } - } + extends SubmitRequestServlet { /** * Build a driver description from the fields specified in the submit request. @@ -389,50 +162,37 @@ private[rest] class SubmitRequestServlet( new DriverDescription( appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command) } -} -/** - * A default servlet that handles error cases that are not captured by other servlets. - */ -private class ErrorServlet extends StandaloneRestServlet { - private val serverVersion = StandaloneRestServer.PROTOCOL_VERSION - - /** Service a faulty request by returning an appropriate error message to the client. */ - protected override def service( - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - val path = request.getPathInfo - val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList - var versionMismatch = false - var msg = - parts match { - case Nil => - // http://host:port/ - "Missing protocol version." - case `serverVersion` :: Nil => - // http://host:port/correct-version - "Missing the /submissions prefix." - case `serverVersion` :: "submissions" :: tail => - // http://host:port/correct-version/submissions/* - "Missing an action: please specify one of /create, /kill, or /status." - case unknownVersion :: tail => - // http://host:port/unknown-version/* - versionMismatch = true - s"Unknown protocol version '$unknownVersion'." - case _ => - // never reached - s"Malformed path $path." - } - msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..." - val error = handleError(msg) - // If there is a version mismatch, include the highest protocol version that - // this server supports in case the client wants to retry with our version - if (versionMismatch) { - error.highestProtocolVersion = serverVersion - response.setStatus(StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) - } else { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + /** + * Handle the submit request and construct an appropriate response to return to the client. + * + * This assumes that the request message is already successfully validated. + * If the request message is not of the expected type, return error to the client. + */ + protected override def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + requestMessage match { + case submitRequest: CreateSubmissionRequest => + val askTimeout = RpcUtils.askTimeout(conf) + val driverDescription = buildDriverDescription(submitRequest) + val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + val submitResponse = new CreateSubmissionResponse + submitResponse.serverSparkVersion = sparkVersion + submitResponse.message = response.message + submitResponse.success = response.success + submitResponse.submissionId = response.driverId.orNull + val unknownFields = findUnknownFields(requestMessageJson, requestMessage) + if (unknownFields.nonEmpty) { + // If there are fields that the server does not know about, warn the client + submitResponse.unknownFields = unknownFields + } + submitResponse + case unexpected => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError(s"Received message of unexpected type ${unexpected.messageType}.") } - sendResponse(error, response) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala index d80abdf15fb34..0d50a768942ed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -61,7 +61,7 @@ private[rest] class CreateSubmissionRequest extends SubmitRestProtocolRequest { assertProperty[Boolean](key, "boolean", _.toBoolean) private def assertPropertyIsNumeric(key: String): Unit = - assertProperty[Int](key, "numeric", _.toInt) + assertProperty[Double](key, "numeric", _.toDouble) private def assertPropertyIsMemory(key: String): Unit = assertProperty[Int](key, "memory", Utils.memoryStringToMb) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala index 8fde8c142a4c1..0e226ee294cab 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala @@ -35,7 +35,7 @@ private[rest] abstract class SubmitRestProtocolResponse extends SubmitRestProtoc /** * A response to a [[CreateSubmissionRequest]] in the REST application submission protocol. */ -private[rest] class CreateSubmissionResponse extends SubmitRestProtocolResponse { +private[spark] class CreateSubmissionResponse extends SubmitRestProtocolResponse { var submissionId: String = null protected override def doValidate(): Unit = { super.doValidate() @@ -46,7 +46,7 @@ private[rest] class CreateSubmissionResponse extends SubmitRestProtocolResponse /** * A response to a kill request in the REST application submission protocol. */ -private[rest] class KillSubmissionResponse extends SubmitRestProtocolResponse { +private[spark] class KillSubmissionResponse extends SubmitRestProtocolResponse { var submissionId: String = null protected override def doValidate(): Unit = { super.doValidate() @@ -58,7 +58,7 @@ private[rest] class KillSubmissionResponse extends SubmitRestProtocolResponse { /** * A response to a status request in the REST application submission protocol. */ -private[rest] class SubmissionStatusResponse extends SubmitRestProtocolResponse { +private[spark] class SubmissionStatusResponse extends SubmitRestProtocolResponse { var submissionId: String = null var driverState: String = null var workerId: String = null diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala new file mode 100644 index 0000000000000..fd17a980c9319 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest.mesos + +import java.io.File +import java.text.SimpleDateFormat +import java.util.Date +import java.util.concurrent.atomic.AtomicLong +import javax.servlet.http.HttpServletResponse + +import org.apache.spark.deploy.Command +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.deploy.rest._ +import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler +import org.apache.spark.util.Utils +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} + + +/** + * A server that responds to requests submitted by the [[RestSubmissionClient]]. + * All requests are forwarded to + * [[org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler]]. + * This is intended to be used in Mesos cluster mode only. + * For more details about the REST submission please refer to [[RestSubmissionServer]] javadocs. + */ +private[spark] class MesosRestServer( + host: String, + requestedPort: Int, + masterConf: SparkConf, + scheduler: MesosClusterScheduler) + extends RestSubmissionServer(host, requestedPort, masterConf) { + + protected override val submitRequestServlet = + new MesosSubmitRequestServlet(scheduler, masterConf) + protected override val killRequestServlet = + new MesosKillRequestServlet(scheduler, masterConf) + protected override val statusRequestServlet = + new MesosStatusRequestServlet(scheduler, masterConf) +} + +private[deploy] class MesosSubmitRequestServlet( + scheduler: MesosClusterScheduler, + conf: SparkConf) + extends SubmitRequestServlet { + + private val DEFAULT_SUPERVISE = false + private val DEFAULT_MEMORY = 512 // mb + private val DEFAULT_CORES = 1.0 + + private val nextDriverNumber = new AtomicLong(0) + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + private def newDriverId(submitDate: Date): String = { + "driver-%s-%04d".format( + createDateFormat.format(submitDate), nextDriverNumber.incrementAndGet()) + } + + /** + * Build a driver description from the fields specified in the submit request. + * + * This involves constructing a command that launches a mesos framework for the job. + * This does not currently consider fields used by python applications since python + * is not supported in mesos cluster mode yet. + */ + private def buildDriverDescription(request: CreateSubmissionRequest): MesosDriverDescription = { + // Required fields, including the main class because python is not yet supported + val appResource = Option(request.appResource).getOrElse { + throw new SubmitRestMissingFieldException("Application jar is missing.") + } + val mainClass = Option(request.mainClass).getOrElse { + throw new SubmitRestMissingFieldException("Main class is missing.") + } + + // Optional fields + val sparkProperties = request.sparkProperties + val driverExtraJavaOptions = sparkProperties.get("spark.driver.extraJavaOptions") + val driverExtraClassPath = sparkProperties.get("spark.driver.extraClassPath") + val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") + val superviseDriver = sparkProperties.get("spark.driver.supervise") + val driverMemory = sparkProperties.get("spark.driver.memory") + val driverCores = sparkProperties.get("spark.driver.cores") + val appArgs = request.appArgs + val environmentVariables = request.environmentVariables + val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass) + + // Construct driver description + val conf = new SparkConf(false).setAll(sparkProperties) + val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) + val sparkJavaOpts = Utils.sparkJavaOpts(conf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts + val command = new Command( + mainClass, appArgs, environmentVariables, extraClassPath, extraLibraryPath, javaOpts) + val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) + val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY) + val actualDriverCores = driverCores.map(_.toDouble).getOrElse(DEFAULT_CORES) + val submitDate = new Date() + val submissionId = newDriverId(submitDate) + + new MesosDriverDescription( + name, appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, + command, request.sparkProperties, submissionId, submitDate) + } + + protected override def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + requestMessage match { + case submitRequest: CreateSubmissionRequest => + val driverDescription = buildDriverDescription(submitRequest) + val s = scheduler.submitDriver(driverDescription) + s.serverSparkVersion = sparkVersion + val unknownFields = findUnknownFields(requestMessageJson, requestMessage) + if (unknownFields.nonEmpty) { + // If there are fields that the server does not know about, warn the client + s.unknownFields = unknownFields + } + s + case unexpected => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError(s"Received message of unexpected type ${unexpected.messageType}.") + } + } +} + +private[deploy] class MesosKillRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) + extends KillRequestServlet { + protected override def handleKill(submissionId: String): KillSubmissionResponse = { + val k = scheduler.killDriver(submissionId) + k.serverSparkVersion = sparkVersion + k + } +} + +private[deploy] class MesosStatusRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) + extends StatusRequestServlet { + protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { + val d = scheduler.getDriverStatus(submissionId) + d.serverSparkVersion = sparkVersion + d + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 3ee2eb69e8a4e..8f3cc54051048 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -34,6 +34,7 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem @@ -61,7 +62,7 @@ private[worker] class Worker( assert (port > 0) // For worker and executor IDs - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 @@ -85,10 +86,10 @@ private[worker] class Worker( private val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders - private val CLEANUP_INTERVAL_MILLIS = + private val CLEANUP_INTERVAL_MILLIS = conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 // TTL for app folders/data; after TTL expires it will be cleaned up - private val APP_DATA_RETENTION_SECS = + private val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) private val testing: Boolean = sys.props.contains("spark.testing") @@ -112,7 +113,7 @@ private[worker] class Worker( } else { new File(sys.env.get("SPARK_HOME").getOrElse(".")) } - + var workDir: File = null val finishedExecutors = new HashMap[String, ExecutorRunner] val drivers = new HashMap[String, DriverRunner] @@ -122,7 +123,7 @@ private[worker] class Worker( val finishedApps = new HashSet[String] // The shuffle service is not actually started unless configured. - private val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr) + private val shuffleService = new ExternalShuffleService(conf, securityMgr) private val publicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") @@ -134,7 +135,7 @@ private[worker] class Worker( private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) private val workerSource = new WorkerSource(this) - + private var registrationRetryTimer: Option[Cancellable] = None var coresUsed = 0 diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 8af46f3327adb..79aed90b53e2f 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -57,7 +57,7 @@ private[spark] class CoarseGrainedExecutorBackend( logInfo("Connecting to driver: " + driverUrl) rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => driver = Some(ref) - ref.sendWithReply[RegisteredExecutor.type]( + ref.ask[RegisteredExecutor.type]( RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) } onComplete { case Success(msg) => Utils.tryLogNonFatalError { @@ -154,7 +154,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { executorConf, new SecurityManager(executorConf)) val driver = fetcher.setupEndpointRefByURI(driverUrl) - val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++ + val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++ Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index f57e215c3f2ed..8f916e0502ecb 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -32,6 +32,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ /** @@ -178,6 +179,7 @@ private[spark] class Executor( } override def run(): Unit = { + val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager) val deserializeStartTime = System.currentTimeMillis() Thread.currentThread.setContextClassLoader(replClassLoader) val ser = env.closureSerializer.newInstance() @@ -190,6 +192,7 @@ private[spark] class Executor( val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) updateDependencies(taskFiles, taskJars) task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + task.setTaskMemoryManager(taskMemoryManager) // If this task has been killed before we deserialized it, let's quit now. Otherwise, // continue executing the task. @@ -206,7 +209,21 @@ private[spark] class Executor( // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() - val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) + val value = try { + task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) + } finally { + // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread; + // when changing this, make sure to update both copies. + val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() + if (freedMemory > 0) { + val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" + if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { + throw new SparkException(errMsg) + } else { + logError(errMsg) + } + } + } val taskFinish = System.currentTimeMillis() // If the task has been killed, let's fail it. @@ -424,7 +441,7 @@ private[spark] class Executor( val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) try { - val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message) + val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](message) if (response.reregisterBlockManager) { logWarning("Told to re-register on heartbeat") env.blockManager.reregister() diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0709b6d689e86..0756cdb2ed8e6 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -97,7 +97,7 @@ private[spark] object CompressionCodec { /** * :: DeveloperApi :: * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. - * Block size can be configured by `spark.io.compression.lz4.block.size`. + * Block size can be configured by `spark.io.compression.lz4.blockSize`. * * Note: The wire protocol for this codec is not guaranteed to be compatible across versions * of Spark. This is intended for use as an internal compression utility within a single Spark @@ -107,7 +107,7 @@ private[spark] object CompressionCodec { class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.getInt("spark.io.compression.lz4.block.size", 32768) + val blockSize = conf.getSizeAsBytes("spark.io.compression.lz4.blockSize", "32k").toInt new LZ4BlockOutputStream(s, blockSize) } @@ -137,7 +137,7 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { /** * :: DeveloperApi :: * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. - * Block size can be configured by `spark.io.compression.snappy.block.size`. + * Block size can be configured by `spark.io.compression.snappy.blockSize`. * * Note: The wire protocol for this codec is not guaranteed to be compatible across versions * of Spark. This is intended for use as an internal compression utility within a single Spark @@ -153,7 +153,7 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { } override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.getInt("spark.io.compression.snappy.block.size", 32768) + val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt new SnappyOutputStream(s, blockSize) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index d80d94a588346..330255f89247f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -407,11 +407,26 @@ abstract class RDD[T: ClassTag]( val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new PartitionwiseSampledRDD[T, T]( - this, new BernoulliCellSampler[T](x(0), x(1)), true, seed) + randomSampleWithRange(x(0), x(1), seed) }.toArray } + /** + * Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability + * range. + * @param lb lower bound to use for the Bernoulli sampler + * @param ub upper bound to use for the Bernoulli sampler + * @param seed the seed for the Random number generator + * @return A random sub-sample of the RDD without replacement. + */ + private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = { + this.mapPartitionsWithIndex { case (index, partition) => + val sampler = new BernoulliCellSampler[T](lb, ub) + sampler.setSeed(seed + index) + sampler.sample(partition) + } + } + /** * Return a fixed-size sampled subset of this RDD in an array * diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala new file mode 100644 index 0000000000000..3e5b64265e919 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +/** + * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe + * and can be called in any thread. + */ +private[spark] trait RpcCallContext { + + /** + * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]] + * will be called. + */ + def reply(response: Any): Unit + + /** + * Report a failure to the sender. + */ + def sendFailure(e: Throwable): Unit + + /** + * The sender of this message. + */ + def sender: RpcEndpointRef +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala new file mode 100644 index 0000000000000..d2b2baef1d8c4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import org.apache.spark.SparkException + +/** + * A factory class to create the [[RpcEnv]]. It must have an empty constructor so that it can be + * created using Reflection. + */ +private[spark] trait RpcEnvFactory { + + def create(config: RpcEnvConfig): RpcEnv +} + +/** + * A trait that requires RpcEnv thread-safely sending messages to it. + * + * Thread-safety means processing of one message happens before processing of the next message by + * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a + * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the + * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. + * + * However, there is no guarantee that the same thread will be executing the same + * [[ThreadSafeRpcEndpoint]] for different messages. + */ +private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint + + +/** + * An end point for the RPC that defines what functions to trigger given a message. + * + * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence. + * + * The lift-cycle will be: + * + * constructor onStart receive* onStop + * + * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use + * [[ThreadSafeRpcEndpoint]] + * + * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be + * invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it. + */ +private[spark] trait RpcEndpoint { + + /** + * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to. + */ + val rpcEnv: RpcEnv + + /** + * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is + * called. And `self` will become `null` when `onStop` is called. + * + * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not + * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called. + */ + final def self: RpcEndpointRef = { + require(rpcEnv != null, "rpcEnv has not been initialized") + rpcEnv.endpointRef(this) + } + + /** + * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a + * unmatched message, [[SparkException]] will be thrown and sent to `onError`. + */ + def receive: PartialFunction[Any, Unit] = { + case _ => throw new SparkException(self + " does not implement 'receive'") + } + + /** + * Process messages from [[RpcEndpointRef.ask]]. If receiving a unmatched message, + * [[SparkException]] will be thrown and sent to `onError`. + */ + def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case _ => context.sendFailure(new SparkException(self + " won't reply anything")) + } + + /** + * Invoked when any exception is thrown during handling messages. + */ + def onError(cause: Throwable): Unit = { + // By default, throw e and let RpcEnv handle it + throw cause + } + + /** + * Invoked before [[RpcEndpoint]] starts to handle any message. + */ + def onStart(): Unit = { + // By default, do nothing. + } + + /** + * Invoked when [[RpcEndpoint]] is stopping. + */ + def onStop(): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is connected to the current node. + */ + def onConnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is lost. + */ + def onDisconnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. + */ + def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * A convenient method to stop [[RpcEndpoint]]. + */ + final def stop(): Unit = { + val _self = self + if (_self != null) { + rpcEnv.stop(_self) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala new file mode 100644 index 0000000000000..69181edb9ad44 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import scala.concurrent.{Await, Future} +import scala.concurrent.duration.FiniteDuration +import scala.reflect.ClassTag + +import org.apache.spark.util.RpcUtils +import org.apache.spark.{SparkException, Logging, SparkConf} + +/** + * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. + */ +private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) + extends Serializable with Logging { + + private[this] val maxRetries = RpcUtils.numRetries(conf) + private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) + private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf) + + /** + * return the address for the [[RpcEndpointRef]] + */ + def address: RpcAddress + + def name: String + + /** + * Sends a one-way asynchronous message. Fire-and-forget semantics. + */ + def send(message: Any): Unit + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to + * receive the reply within the specified timeout. + * + * This method only sends the message once and never retries. + */ + def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to + * receive the reply within a default timeout. + * + * This method only sends the message once and never retries. + */ + def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultAskTimeout) + + /** + * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default + * timeout, or throw a SparkException if this fails even after the default number of retries. + * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this + * method retries, the message handling in the receiver side should be idempotent. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithRetry[T: ClassTag](message: Any): T = askWithRetry(message, defaultAskTimeout) + + /** + * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a + * specified timeout, throw a SparkException if this fails even after the specified number of + * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method + * retries, the message handling in the receiver side should be idempotent. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @param timeout the timeout duration + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithRetry[T: ClassTag](message: Any, timeout: FiniteDuration): T = { + // TODO: Consider removing multiple attempts + var attempts = 0 + var lastException: Exception = null + while (attempts < maxRetries) { + attempts += 1 + try { + val future = ask[T](message, timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new SparkException("Actor returned null") + } + return result + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning(s"Error sending message [message = $message] in $attempts attempts", e) + } + Thread.sleep(retryWaitMs) + } + + throw new SparkException( + s"Error sending message [message = $message]", lastException) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index a5336b7563802..12b6b28d4d7ec 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -20,13 +20,41 @@ package org.apache.spark.rpc import java.net.URI import scala.concurrent.{Await, Future} -import scala.concurrent.duration._ import scala.language.postfixOps -import scala.reflect.ClassTag -import org.apache.spark.{Logging, SparkException, SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.util.{RpcUtils, Utils} + +/** + * A RpcEnv implementation must have a [[RpcEnvFactory]] implementation with an empty constructor + * so that it can be created via Reflection. + */ +private[spark] object RpcEnv { + + private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { + // Add more RpcEnv implementations here + val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") + val rpcEnvName = conf.get("spark.rpc", "akka") + val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) + Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). + newInstance().asInstanceOf[RpcEnvFactory] + } + + def create( + name: String, + host: String, + port: Int, + conf: SparkConf, + securityManager: SecurityManager): RpcEnv = { + // Using Reflection to create the RpcEnv to avoid to depend on Akka directly + val config = RpcEnvConfig(conf, name, host, port, securityManager) + getRpcEnvFactory(conf).create(config) + } + +} + + /** * An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to * receives messages. Then [[RpcEnv]] will process messages sent from [[RpcEndpointRef]] or remote @@ -112,6 +140,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { def uriOf(systemName: String, address: RpcAddress, endpointName: String): String } + private[spark] case class RpcEnvConfig( conf: SparkConf, name: String, @@ -119,261 +148,9 @@ private[spark] case class RpcEnvConfig( port: Int, securityManager: SecurityManager) -/** - * A RpcEnv implementation must have a [[RpcEnvFactory]] implementation with an empty constructor - * so that it can be created via Reflection. - */ -private[spark] object RpcEnv { - - private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { - // Add more RpcEnv implementations here - val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") - val rpcEnvName = conf.get("spark.rpc", "akka") - val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) - Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). - newInstance().asInstanceOf[RpcEnvFactory] - } - - def create( - name: String, - host: String, - port: Int, - conf: SparkConf, - securityManager: SecurityManager): RpcEnv = { - // Using Reflection to create the RpcEnv to avoid to depend on Akka directly - val config = RpcEnvConfig(conf, name, host, port, securityManager) - getRpcEnvFactory(conf).create(config) - } - -} - -/** - * A factory class to create the [[RpcEnv]]. It must have an empty constructor so that it can be - * created using Reflection. - */ -private[spark] trait RpcEnvFactory { - - def create(config: RpcEnvConfig): RpcEnv -} /** - * An end point for the RPC that defines what functions to trigger given a message. - * - * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence. - * - * The lift-cycle will be: - * - * constructor onStart receive* onStop - * - * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use - * [[ThreadSafeRpcEndpoint]] - * - * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be - * invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it. - */ -private[spark] trait RpcEndpoint { - - /** - * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to. - */ - val rpcEnv: RpcEnv - - /** - * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is - * called. And `self` will become `null` when `onStop` is called. - * - * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not - * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called. - */ - final def self: RpcEndpointRef = { - require(rpcEnv != null, "rpcEnv has not been initialized") - rpcEnv.endpointRef(this) - } - - /** - * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a - * unmatched message, [[SparkException]] will be thrown and sent to `onError`. - */ - def receive: PartialFunction[Any, Unit] = { - case _ => throw new SparkException(self + " does not implement 'receive'") - } - - /** - * Process messages from [[RpcEndpointRef.sendWithReply]]. If receiving a unmatched message, - * [[SparkException]] will be thrown and sent to `onError`. - */ - def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case _ => context.sendFailure(new SparkException(self + " won't reply anything")) - } - - /** - * Call onError when any exception is thrown during handling messages. - * - * @param cause - */ - def onError(cause: Throwable): Unit = { - // By default, throw e and let RpcEnv handle it - throw cause - } - - /** - * Invoked before [[RpcEndpoint]] starts to handle any message. - */ - def onStart(): Unit = { - // By default, do nothing. - } - - /** - * Invoked when [[RpcEndpoint]] is stopping. - */ - def onStop(): Unit = { - // By default, do nothing. - } - - /** - * Invoked when `remoteAddress` is connected to the current node. - */ - def onConnected(remoteAddress: RpcAddress): Unit = { - // By default, do nothing. - } - - /** - * Invoked when `remoteAddress` is lost. - */ - def onDisconnected(remoteAddress: RpcAddress): Unit = { - // By default, do nothing. - } - - /** - * Invoked when some network error happens in the connection between the current node and - * `remoteAddress`. - */ - def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { - // By default, do nothing. - } - - /** - * A convenient method to stop [[RpcEndpoint]]. - */ - final def stop(): Unit = { - val _self = self - if (_self != null) { - rpcEnv.stop(_self) - } - } -} - -/** - * A trait that requires RpcEnv thread-safely sending messages to it. - * - * Thread-safety means processing of one message happens before processing of the next message by - * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a - * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the - * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. - * - * However, there is no guarantee that the same thread will be executing the same - * [[ThreadSafeRpcEndpoint]] for different messages. - */ -trait ThreadSafeRpcEndpoint extends RpcEndpoint - -/** - * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. - */ -private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) - extends Serializable with Logging { - - private[this] val maxRetries = RpcUtils.numRetries(conf) - private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) - private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf) - - /** - * return the address for the [[RpcEndpointRef]] - */ - def address: RpcAddress - - def name: String - - /** - * Sends a one-way asynchronous message. Fire-and-forget semantics. - */ - def send(message: Any): Unit - - /** - * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to - * receive the reply within a default timeout. - * - * This method only sends the message once and never retries. - */ - def sendWithReply[T: ClassTag](message: Any): Future[T] = - sendWithReply(message, defaultAskTimeout) - - /** - * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to - * receive the reply within the specified timeout. - * - * This method only sends the message once and never retries. - */ - def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] - - /** - * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default - * timeout, or throw a SparkException if this fails even after the default number of retries. - * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this - * method retries, the message handling in the receiver side should be idempotent. - * - * Note: this is a blocking action which may cost a lot of time, so don't call it in an message - * loop of [[RpcEndpoint]]. - * - * @param message the message to send - * @tparam T type of the reply message - * @return the reply message from the corresponding [[RpcEndpoint]] - */ - def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultAskTimeout) - - /** - * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a - * specified timeout, throw a SparkException if this fails even after the specified number of - * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method - * retries, the message handling in the receiver side should be idempotent. - * - * Note: this is a blocking action which may cost a lot of time, so don't call it in an message - * loop of [[RpcEndpoint]]. - * - * @param message the message to send - * @param timeout the timeout duration - * @tparam T type of the reply message - * @return the reply message from the corresponding [[RpcEndpoint]] - */ - def askWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): T = { - // TODO: Consider removing multiple attempts - var attempts = 0 - var lastException: Exception = null - while (attempts < maxRetries) { - attempts += 1 - try { - val future = sendWithReply[T](message, timeout) - val result = Await.result(future, timeout) - if (result == null) { - throw new SparkException("Actor returned null") - } - return result - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning(s"Error sending message [message = $message] in $attempts attempts", e) - } - Thread.sleep(retryWaitMs) - } - - throw new SparkException( - s"Error sending message [message = $message]", lastException) - } - -} - -/** - * Represent a host with a port + * Represents a host and port. */ private[spark] case class RpcAddress(host: String, port: Int) { // TODO do we need to add the type of RpcEnv in the address? @@ -383,6 +160,7 @@ private[spark] case class RpcAddress(host: String, port: Int) { override val toString: String = hostPort } + private[spark] object RpcAddress { /** @@ -404,26 +182,3 @@ private[spark] object RpcAddress { RpcAddress(host, port) } } - -/** - * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe - * and can be called in any thread. - */ -private[spark] trait RpcCallContext { - - /** - * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]] - * will be called. - */ - def reply(response: Any): Unit - - /** - * Report a failure to the sender. - */ - def sendFailure(e: Throwable): Unit - - /** - * The sender of this message. - */ - def sender: RpcEndpointRef -} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 652e52f2b2e73..ba0d468f111ef 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -293,7 +293,7 @@ private[akka] class AkkaRpcEndpointRef( actorRef ! AkkaMessage(message, false) } - override def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { + override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { import scala.concurrent.ExecutionContext.Implicits.global actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { case msg @ AkkaMessage(message, reply) => diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8c4bff4e83afc..05b8ab0d0a1f9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -28,12 +28,15 @@ import scala.language.existentials import scala.language.postfixOps import scala.util.control.NonFatal +import org.apache.commons.lang3.SerializationUtils + import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage._ +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -166,7 +169,7 @@ class DAGScheduler( taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) - blockManagerMaster.driverEndpoint.askWithReply[Boolean]( + blockManagerMaster.driverEndpoint.askWithRetry[Boolean]( BlockManagerHeartbeat(blockManagerId), 600 seconds) } @@ -509,7 +512,8 @@ class DAGScheduler( val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) eventProcessLoop.post(JobSubmitted( - jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)) + jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, + SerializationUtils.clone(properties))) waiter } @@ -546,7 +550,8 @@ class DAGScheduler( val partitions = (0 until rdd.partitions.size).toArray val jobId = nextJobId.getAndIncrement() eventProcessLoop.post(JobSubmitted( - jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)) + jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, + SerializationUtils.clone(properties))) listener.awaitResult() // Will throw an exception if the job fails } @@ -643,8 +648,15 @@ class DAGScheduler( try { val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) - val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0, - attemptNumber = 0, runningLocally = true) + val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager) + val taskContext = + new TaskContextImpl( + job.finalStage.id, + job.partitions(0), + taskAttemptId = 0, + attemptNumber = 0, + taskMemoryManager = taskMemoryManager, + runningLocally = true) TaskContext.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) @@ -652,6 +664,16 @@ class DAGScheduler( } finally { taskContext.markTaskCompleted() TaskContext.unset() + // Note: this memory freeing logic is duplicated in Executor.run(); when changing this, + // make sure to update both copies. + val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() + if (freedMemory > 0) { + if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { + throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes") + } else { + logError(s"Managed memory leak detected; size = $freedMemory bytes") + } + } } } catch { case e: Exception => @@ -686,8 +708,11 @@ class DAGScheduler( private[scheduler] def handleJobGroupCancelled(groupId: String) { // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. - val activeInGroup = activeJobs.filter(activeJob => - Option(activeJob.properties).exists(_.get(SparkContext.SPARK_JOB_GROUP_ID) == groupId)) + val activeInGroup = activeJobs.filter { activeJob => + Option(activeJob.properties).exists { + _.getProperty(SparkContext.SPARK_JOB_GROUP_ID) == groupId + } + } val jobIds = activeInGroup.map(_.jobId) jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId))) submitWaitingStages() diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 7c184b1dcb308..0b1d47cff3746 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -85,7 +85,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { val msg = AskPermissionToCommitOutput(stage, partition, attempt) coordinatorRef match { case Some(endpointRef) => - endpointRef.askWithReply[Boolean](msg) + endpointRef.askWithRetry[Boolean](msg) case None => logError( "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?") diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index b09b19e2ac9e7..586d1e06204c1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -25,6 +25,7 @@ import scala.collection.mutable.HashMap import org.apache.spark.{TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.ByteBufferInputStream import org.apache.spark.util.Utils @@ -52,8 +53,13 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex * @return the result of the task */ final def run(taskAttemptId: Long, attemptNumber: Int): T = { - context = new TaskContextImpl(stageId = stageId, partitionId = partitionId, - taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false) + context = new TaskContextImpl( + stageId = stageId, + partitionId = partitionId, + taskAttemptId = taskAttemptId, + attemptNumber = attemptNumber, + taskMemoryManager = taskMemoryManager, + runningLocally = false) TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) taskThread = Thread.currentThread() @@ -68,6 +74,12 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex } } + private var taskMemoryManager: TaskMemoryManager = _ + + def setTaskMemoryManager(taskMemoryManager: TaskMemoryManager): Unit = { + this.taskMemoryManager = taskMemoryManager + } + def runTask(context: TaskContext): T def preferredLocations: Seq[TaskLocation] = Nil diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 9656fb76858ea..7352fa1fe9ebd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -252,7 +252,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp try { if (driverEndpoint != null) { logInfo("Shutting down all executors") - driverEndpoint.askWithReply[Boolean](StopExecutors) + driverEndpoint.askWithRetry[Boolean](StopExecutors) } } catch { case e: Exception => @@ -264,7 +264,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp stopExecutors() try { if (driverEndpoint != null) { - driverEndpoint.askWithReply[Boolean](StopDriver) + driverEndpoint.askWithRetry[Boolean](StopDriver) } } catch { case e: Exception => @@ -287,7 +287,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Called by subclasses when notified of a lost worker def removeExecutor(executorId: String, reason: String) { try { - driverEndpoint.askWithReply[Boolean](RemoveExecutor(executorId, reason)) + driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) } catch { case e: Exception => throw new SparkException("Error notifying standalone scheduler's driver endpoint", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index d987c7d563579..2a3a5d925d06f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -53,14 +53,14 @@ private[spark] abstract class YarnSchedulerBackend( * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - yarnSchedulerEndpoint.askWithReply[Boolean](RequestExecutors(requestedTotal)) + yarnSchedulerEndpoint.askWithRetry[Boolean](RequestExecutors(requestedTotal)) } /** * Request that the ApplicationMaster kill the specified executors. */ override def doKillExecutors(executorIds: Seq[String]): Boolean = { - yarnSchedulerEndpoint.askWithReply[Boolean](KillExecutors(executorIds)) + yarnSchedulerEndpoint.askWithRetry[Boolean](KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -115,7 +115,7 @@ private[spark] abstract class YarnSchedulerBackend( amEndpoint match { case Some(am) => Future { - context.reply(am.askWithReply[Boolean](r)) + context.reply(am.askWithRetry[Boolean](r)) } onFailure { case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e) @@ -130,7 +130,7 @@ private[spark] abstract class YarnSchedulerBackend( amEndpoint match { case Some(am) => Future { - context.reply(am.askWithReply[Boolean](k)) + context.reply(am.askWithRetry[Boolean](k)) } onFailure { case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 82f652dae0378..3412301e64fd7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,20 +18,17 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{List => JList} -import java.util.Collections +import java.util.{Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} - -import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException, TaskState} +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -49,17 +46,10 @@ private[spark] class CoarseMesosSchedulerBackend( master: String) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler - with Logging { + with MesosSchedulerUtils { val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() - - // Driver for talking to Mesos - var driver: SchedulerDriver = null - // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt @@ -87,26 +77,8 @@ private[spark] class CoarseMesosSchedulerBackend( override def start() { super.start() - - synchronized { - new Thread("CoarseMesosSchedulerBackend driver") { - setDaemon(true) - override def run() { - val scheduler = CoarseMesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() - driver = new MesosSchedulerDriver(scheduler, fwInfo, master) - try { { - val ret = driver.run() - logInfo("driver.run() returned with code " + ret) - } - } catch { - case e: Exception => logError("driver.run() failed", e) - } - } - }.start() - - waitForRegister() - } + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() + startScheduler(master, CoarseMesosSchedulerBackend.this, fwInfo) } def createCommand(offer: Offer, numCores: Int): CommandInfo = { @@ -150,8 +122,10 @@ private[spark] class CoarseMesosSchedulerBackend( conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ENDPOINT_NAME) - val uri = conf.get("spark.executor.uri", null) - if (uri == null) { + val uri = conf.getOption("spark.executor.uri") + .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) + + if (uri.isEmpty) { val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" @@ -164,7 +138,7 @@ private[spark] class CoarseMesosSchedulerBackend( } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". - val basename = uri.split('/').last.split('.').head + val basename = uri.get.split('/').last.split('.').head command.setValue( s"cd $basename*; $prefixEnv " + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + @@ -173,7 +147,7 @@ private[spark] class CoarseMesosSchedulerBackend( s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") - command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) + command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } command.build() } @@ -183,18 +157,7 @@ private[spark] class CoarseMesosSchedulerBackend( override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) - registeredLock.synchronized { - isRegistered = true - registeredLock.notifyAll() - } - } - - def waitForRegister() { - registeredLock.synchronized { - while (!isRegistered) { - registeredLock.wait() - } - } + markRegistered() } override def disconnected(d: SchedulerDriver) {} @@ -245,14 +208,6 @@ private[spark] class CoarseMesosSchedulerBackend( } } - /** Helper function to pull out a resource from a Mesos Resources protobuf */ - private def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue - } - 0 - } - /** Build a Mesos resource protobuf object */ private def createResource(resourceName: String, quantity: Double): Protos.Resource = { Resource.newBuilder() @@ -284,7 +239,8 @@ private[spark] class CoarseMesosSchedulerBackend( "is Spark installed on it?") } } - driver.reviveOffers() // In case we'd rejected everything before but have now lost a node + // In case we'd rejected everything before but have now lost a node + mesosDriver.reviveOffers() } } } @@ -296,8 +252,8 @@ private[spark] class CoarseMesosSchedulerBackend( override def stop() { super.stop() - if (driver != null) { - driver.stop() + if (mesosDriver != null) { + mesosDriver.stop() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala new file mode 100644 index 0000000000000..3efc536f1456c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import scala.collection.JavaConversions._ + +import org.apache.curator.framework.CuratorFramework +import org.apache.zookeeper.CreateMode +import org.apache.zookeeper.KeeperException.NoNodeException + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkCuratorUtil +import org.apache.spark.util.Utils + +/** + * Persistence engine factory that is responsible for creating new persistence engines + * to store Mesos cluster mode state. + */ +private[spark] abstract class MesosClusterPersistenceEngineFactory(conf: SparkConf) { + def createEngine(path: String): MesosClusterPersistenceEngine +} + +/** + * Mesos cluster persistence engine is responsible for persisting Mesos cluster mode + * specific state, so that on failover all the state can be recovered and the scheduler + * can resume managing the drivers. + */ +private[spark] trait MesosClusterPersistenceEngine { + def persist(name: String, obj: Object): Unit + def expunge(name: String): Unit + def fetch[T](name: String): Option[T] + def fetchAll[T](): Iterable[T] +} + +/** + * Zookeeper backed persistence engine factory. + * All Zk engines created from this factory shares the same Zookeeper client, so + * all of them reuses the same connection pool. + */ +private[spark] class ZookeeperMesosClusterPersistenceEngineFactory(conf: SparkConf) + extends MesosClusterPersistenceEngineFactory(conf) { + + lazy val zk = SparkCuratorUtil.newClient(conf, "spark.mesos.deploy.zookeeper.url") + + def createEngine(path: String): MesosClusterPersistenceEngine = { + new ZookeeperMesosClusterPersistenceEngine(path, zk, conf) + } +} + +/** + * Black hole persistence engine factory that creates black hole + * persistence engines, which stores nothing. + */ +private[spark] class BlackHoleMesosClusterPersistenceEngineFactory + extends MesosClusterPersistenceEngineFactory(null) { + def createEngine(path: String): MesosClusterPersistenceEngine = { + new BlackHoleMesosClusterPersistenceEngine + } +} + +/** + * Black hole persistence engine that stores nothing. + */ +private[spark] class BlackHoleMesosClusterPersistenceEngine extends MesosClusterPersistenceEngine { + override def persist(name: String, obj: Object): Unit = {} + override def fetch[T](name: String): Option[T] = None + override def expunge(name: String): Unit = {} + override def fetchAll[T](): Iterable[T] = Iterable.empty[T] +} + +/** + * Zookeeper based Mesos cluster persistence engine, that stores cluster mode state + * into Zookeeper. Each engine object is operating under one folder in Zookeeper, but + * reuses a shared Zookeeper client. + */ +private[spark] class ZookeeperMesosClusterPersistenceEngine( + baseDir: String, + zk: CuratorFramework, + conf: SparkConf) + extends MesosClusterPersistenceEngine with Logging { + private val WORKING_DIR = + conf.get("spark.deploy.zookeeper.dir", "/spark_mesos_dispatcher") + "/" + baseDir + + SparkCuratorUtil.mkdir(zk, WORKING_DIR) + + def path(name: String): String = { + WORKING_DIR + "/" + name + } + + override def expunge(name: String): Unit = { + zk.delete().forPath(path(name)) + } + + override def persist(name: String, obj: Object): Unit = { + val serialized = Utils.serialize(obj) + val zkPath = path(name) + zk.create().withMode(CreateMode.PERSISTENT).forPath(zkPath, serialized) + } + + override def fetch[T](name: String): Option[T] = { + val zkPath = path(name) + + try { + val fileData = zk.getData().forPath(zkPath) + Some(Utils.deserialize[T](fileData)) + } catch { + case e: NoNodeException => None + case e: Exception => { + logWarning("Exception while reading persisted file, deleting", e) + zk.delete().forPath(zkPath) + None + } + } + } + + override def fetchAll[T](): Iterable[T] = { + zk.getChildren.forPath(WORKING_DIR).map(fetch[T]).flatten + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala new file mode 100644 index 0000000000000..0396e62be5309 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -0,0 +1,608 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.io.File +import java.util.concurrent.locks.ReentrantLock +import java.util.{Collections, Date, List => JList} + +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.mesos.Protos.Environment.Variable +import org.apache.mesos.Protos.TaskStatus.Reason +import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} +import org.apache.mesos.{Scheduler, SchedulerDriver} +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.util.Utils +import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} + + +/** + * Tracks the current state of a Mesos Task that runs a Spark driver. + * @param driverDescription Submitted driver description from + * [[org.apache.spark.deploy.rest.mesos.MesosRestServer]] + * @param taskId Mesos TaskID generated for the task + * @param slaveId Slave ID that the task is assigned to + * @param mesosTaskStatus The last known task status update. + * @param startDate The date the task was launched + */ +private[spark] class MesosClusterSubmissionState( + val driverDescription: MesosDriverDescription, + val taskId: TaskID, + val slaveId: SlaveID, + var mesosTaskStatus: Option[TaskStatus], + var startDate: Date) + extends Serializable { + + def copy(): MesosClusterSubmissionState = { + new MesosClusterSubmissionState( + driverDescription, taskId, slaveId, mesosTaskStatus, startDate) + } +} + +/** + * Tracks the retry state of a driver, which includes the next time it should be scheduled + * and necessary information to do exponential backoff. + * This class is not thread-safe, and we expect the caller to handle synchronizing state. + * @param lastFailureStatus Last Task status when it failed. + * @param retries Number of times it has been retried. + * @param nextRetry Time at which it should be retried next + * @param waitTime The amount of time driver is scheduled to wait until next retry. + */ +private[spark] class MesosClusterRetryState( + val lastFailureStatus: TaskStatus, + val retries: Int, + val nextRetry: Date, + val waitTime: Int) extends Serializable { + def copy(): MesosClusterRetryState = + new MesosClusterRetryState(lastFailureStatus, retries, nextRetry, waitTime) +} + +/** + * The full state of the cluster scheduler, currently being used for displaying + * information on the UI. + * @param frameworkId Mesos Framework id for the cluster scheduler. + * @param masterUrl The Mesos master url + * @param queuedDrivers All drivers queued to be launched + * @param launchedDrivers All launched or running drivers + * @param finishedDrivers All terminated drivers + * @param pendingRetryDrivers All drivers pending to be retried + */ +private[spark] class MesosClusterSchedulerState( + val frameworkId: String, + val masterUrl: Option[String], + val queuedDrivers: Iterable[MesosDriverDescription], + val launchedDrivers: Iterable[MesosClusterSubmissionState], + val finishedDrivers: Iterable[MesosClusterSubmissionState], + val pendingRetryDrivers: Iterable[MesosDriverDescription]) + +/** + * A Mesos scheduler that is responsible for launching submitted Spark drivers in cluster mode + * as Mesos tasks in a Mesos cluster. + * All drivers are launched asynchronously by the framework, which will eventually be launched + * by one of the slaves in the cluster. The results of the driver will be stored in slave's task + * sandbox which is accessible by visiting the Mesos UI. + * This scheduler supports recovery by persisting all its state and performs task reconciliation + * on recover, which gets all the latest state for all the drivers from Mesos master. + */ +private[spark] class MesosClusterScheduler( + engineFactory: MesosClusterPersistenceEngineFactory, + conf: SparkConf) + extends Scheduler with MesosSchedulerUtils { + var frameworkUrl: String = _ + private val metricsSystem = + MetricsSystem.createMetricsSystem("mesos_cluster", conf, new SecurityManager(conf)) + private val master = conf.get("spark.master") + private val appName = conf.get("spark.app.name") + private val queuedCapacity = conf.getInt("spark.mesos.maxDrivers", 200) + private val retainedDrivers = conf.getInt("spark.mesos.retainedDrivers", 200) + private val maxRetryWaitTime = conf.getInt("spark.mesos.cluster.retry.wait.max", 60) // 1 minute + private val schedulerState = engineFactory.createEngine("scheduler") + private val stateLock = new ReentrantLock() + private val finishedDrivers = + new mutable.ArrayBuffer[MesosClusterSubmissionState](retainedDrivers) + private var frameworkId: String = null + // Holds all the launched drivers and current launch state, keyed by driver id. + private val launchedDrivers = new mutable.HashMap[String, MesosClusterSubmissionState]() + // Holds a map of driver id to expected slave id that is passed to Mesos for reconciliation. + // All drivers that are loaded after failover are added here, as we need get the latest + // state of the tasks from Mesos. + private val pendingRecover = new mutable.HashMap[String, SlaveID]() + // Stores all the submitted drivers that hasn't been launched. + private val queuedDrivers = new ArrayBuffer[MesosDriverDescription]() + // All supervised drivers that are waiting to retry after termination. + private val pendingRetryDrivers = new ArrayBuffer[MesosDriverDescription]() + private val queuedDriversState = engineFactory.createEngine("driverQueue") + private val launchedDriversState = engineFactory.createEngine("launchedDrivers") + private val pendingRetryDriversState = engineFactory.createEngine("retryList") + // Flag to mark if the scheduler is ready to be called, which is until the scheduler + // is registered with Mesos master. + @volatile protected var ready = false + private var masterInfo: Option[MasterInfo] = None + + def submitDriver(desc: MesosDriverDescription): CreateSubmissionResponse = { + val c = new CreateSubmissionResponse + if (!ready) { + c.success = false + c.message = "Scheduler is not ready to take requests" + return c + } + + stateLock.synchronized { + if (isQueueFull()) { + c.success = false + c.message = "Already reached maximum submission size" + return c + } + c.submissionId = desc.submissionId + queuedDriversState.persist(desc.submissionId, desc) + queuedDrivers += desc + c.success = true + } + c + } + + def killDriver(submissionId: String): KillSubmissionResponse = { + val k = new KillSubmissionResponse + if (!ready) { + k.success = false + k.message = "Scheduler is not ready to take requests" + return k + } + k.submissionId = submissionId + stateLock.synchronized { + // We look for the requested driver in the following places: + // 1. Check if submission is running or launched. + // 2. Check if it's still queued. + // 3. Check if it's in the retry list. + // 4. Check if it has already completed. + if (launchedDrivers.contains(submissionId)) { + val task = launchedDrivers(submissionId) + mesosDriver.killTask(task.taskId) + k.success = true + k.message = "Killing running driver" + } else if (removeFromQueuedDrivers(submissionId)) { + k.success = true + k.message = "Removed driver while it's still pending" + } else if (removeFromPendingRetryDrivers(submissionId)) { + k.success = true + k.message = "Removed driver while it's being retried" + } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) { + k.success = false + k.message = "Driver already terminated" + } else { + k.success = false + k.message = "Cannot find driver" + } + } + k + } + + def getDriverStatus(submissionId: String): SubmissionStatusResponse = { + val s = new SubmissionStatusResponse + if (!ready) { + s.success = false + s.message = "Scheduler is not ready to take requests" + return s + } + s.submissionId = submissionId + stateLock.synchronized { + if (queuedDrivers.exists(_.submissionId.equals(submissionId))) { + s.success = true + s.driverState = "QUEUED" + } else if (launchedDrivers.contains(submissionId)) { + s.success = true + s.driverState = "RUNNING" + launchedDrivers(submissionId).mesosTaskStatus.foreach(state => s.message = state.toString) + } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) { + s.success = true + s.driverState = "FINISHED" + finishedDrivers + .find(d => d.driverDescription.submissionId.equals(submissionId)).get.mesosTaskStatus + .foreach(state => s.message = state.toString) + } else if (pendingRetryDrivers.exists(_.submissionId.equals(submissionId))) { + val status = pendingRetryDrivers.find(_.submissionId.equals(submissionId)) + .get.retryState.get.lastFailureStatus + s.success = true + s.driverState = "RETRYING" + s.message = status.toString + } else { + s.success = false + s.driverState = "NOT_FOUND" + } + } + s + } + + private def isQueueFull(): Boolean = launchedDrivers.size >= queuedCapacity + + /** + * Recover scheduler state that is persisted. + * We still need to do task reconciliation to be up to date of the latest task states + * as it might have changed while the scheduler is failing over. + */ + private def recoverState(): Unit = { + stateLock.synchronized { + launchedDriversState.fetchAll[MesosClusterSubmissionState]().foreach { state => + launchedDrivers(state.taskId.getValue) = state + pendingRecover(state.taskId.getValue) = state.slaveId + } + queuedDriversState.fetchAll[MesosDriverDescription]().foreach(d => queuedDrivers += d) + // There is potential timing issue where a queued driver might have been launched + // but the scheduler shuts down before the queued driver was able to be removed + // from the queue. We try to mitigate this issue by walking through all queued drivers + // and remove if they're already launched. + queuedDrivers + .filter(d => launchedDrivers.contains(d.submissionId)) + .foreach(d => removeFromQueuedDrivers(d.submissionId)) + pendingRetryDriversState.fetchAll[MesosDriverDescription]() + .foreach(s => pendingRetryDrivers += s) + // TODO: Consider storing finished drivers so we can show them on the UI after + // failover. For now we clear the history on each recovery. + finishedDrivers.clear() + } + } + + /** + * Starts the cluster scheduler and wait until the scheduler is registered. + * This also marks the scheduler to be ready for requests. + */ + def start(): Unit = { + // TODO: Implement leader election to make sure only one framework running in the cluster. + val fwId = schedulerState.fetch[String]("frameworkId") + val builder = FrameworkInfo.newBuilder() + .setUser(Utils.getCurrentUserName()) + .setName(appName) + .setWebuiUrl(frameworkUrl) + .setCheckpoint(true) + .setFailoverTimeout(Integer.MAX_VALUE) // Setting to max so tasks keep running on crash + fwId.foreach { id => + builder.setId(FrameworkID.newBuilder().setValue(id).build()) + frameworkId = id + } + recoverState() + metricsSystem.registerSource(new MesosClusterSchedulerSource(this)) + metricsSystem.start() + startScheduler(master, MesosClusterScheduler.this, builder.build()) + ready = true + } + + def stop(): Unit = { + ready = false + metricsSystem.report() + metricsSystem.stop() + mesosDriver.stop(true) + } + + override def registered( + driver: SchedulerDriver, + newFrameworkId: FrameworkID, + masterInfo: MasterInfo): Unit = { + logInfo("Registered as framework ID " + newFrameworkId.getValue) + if (newFrameworkId.getValue != frameworkId) { + frameworkId = newFrameworkId.getValue + schedulerState.persist("frameworkId", frameworkId) + } + markRegistered() + + stateLock.synchronized { + this.masterInfo = Some(masterInfo) + if (!pendingRecover.isEmpty) { + // Start task reconciliation if we need to recover. + val statuses = pendingRecover.collect { + case (taskId, slaveId) => + val newStatus = TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId).build()) + .setSlaveId(slaveId) + .setState(MesosTaskState.TASK_STAGING) + .build() + launchedDrivers.get(taskId).map(_.mesosTaskStatus.getOrElse(newStatus)) + .getOrElse(newStatus) + } + // TODO: Page the status updates to avoid trying to reconcile + // a large amount of tasks at once. + driver.reconcileTasks(statuses) + } + } + } + + private def buildDriverCommand(desc: MesosDriverDescription): CommandInfo = { + val appJar = CommandInfo.URI.newBuilder() + .setValue(desc.jarUrl.stripPrefix("file:").stripPrefix("local:")).build() + val builder = CommandInfo.newBuilder().addUris(appJar) + val entries = + (conf.getOption("spark.executor.extraLibraryPath").toList ++ + desc.command.libraryPathEntries) + val prefixEnv = if (!entries.isEmpty) { + Utils.libraryPathEnvPrefix(entries) + } else { + "" + } + val envBuilder = Environment.newBuilder() + desc.command.environment.foreach { case (k, v) => + envBuilder.addVariables(Variable.newBuilder().setName(k).setValue(v).build()) + } + // Pass all spark properties to executor. + val executorOpts = desc.schedulerProperties.map { case (k, v) => s"-D$k=$v" }.mkString(" ") + envBuilder.addVariables( + Variable.newBuilder().setName("SPARK_EXECUTOR_OPTS").setValue(executorOpts)) + val cmdOptions = generateCmdOption(desc) + val executorUri = desc.schedulerProperties.get("spark.executor.uri") + .orElse(desc.command.environment.get("SPARK_EXECUTOR_URI")) + val appArguments = desc.command.arguments.mkString(" ") + val cmd = if (executorUri.isDefined) { + builder.addUris(CommandInfo.URI.newBuilder().setValue(executorUri.get).build()) + val folderBasename = executorUri.get.split('/').last.split('.').head + val cmdExecutable = s"cd $folderBasename*; $prefixEnv bin/spark-submit" + val cmdJar = s"../${desc.jarUrl.split("/").last}" + s"$cmdExecutable ${cmdOptions.mkString(" ")} $cmdJar $appArguments" + } else { + val executorSparkHome = desc.schedulerProperties.get("spark.mesos.executor.home") + .orElse(conf.getOption("spark.home")) + .orElse(Option(System.getenv("SPARK_HOME"))) + .getOrElse { + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } + val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getCanonicalPath + val cmdJar = desc.jarUrl.split("/").last + s"$cmdExecutable ${cmdOptions.mkString(" ")} $cmdJar $appArguments" + } + builder.setValue(cmd) + builder.setEnvironment(envBuilder.build()) + builder.build() + } + + private def generateCmdOption(desc: MesosDriverDescription): Seq[String] = { + var options = Seq( + "--name", desc.schedulerProperties("spark.app.name"), + "--class", desc.command.mainClass, + "--master", s"mesos://${conf.get("spark.master")}", + "--driver-cores", desc.cores.toString, + "--driver-memory", s"${desc.mem}M") + desc.schedulerProperties.get("spark.executor.memory").map { v => + options ++= Seq("--executor-memory", v) + } + desc.schedulerProperties.get("spark.cores.max").map { v => + options ++= Seq("--total-executor-cores", v) + } + options + } + + private class ResourceOffer(val offer: Offer, var cpu: Double, var mem: Double) { + override def toString(): String = { + s"Offer id: ${offer.getId.getValue}, cpu: $cpu, mem: $mem" + } + } + + /** + * This method takes all the possible candidates and attempt to schedule them with Mesos offers. + * Every time a new task is scheduled, the afterLaunchCallback is called to perform post scheduled + * logic on each task. + */ + private def scheduleTasks( + candidates: Seq[MesosDriverDescription], + afterLaunchCallback: (String) => Boolean, + currentOffers: List[ResourceOffer], + tasks: mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]): Unit = { + for (submission <- candidates) { + val driverCpu = submission.cores + val driverMem = submission.mem + logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem") + val offerOption = currentOffers.find { o => + o.cpu >= driverCpu && o.mem >= driverMem + } + if (offerOption.isEmpty) { + logDebug(s"Unable to find offer to launch driver id: ${submission.submissionId}, " + + s"cpu: $driverCpu, mem: $driverMem") + } else { + val offer = offerOption.get + offer.cpu -= driverCpu + offer.mem -= driverMem + val taskId = TaskID.newBuilder().setValue(submission.submissionId).build() + val cpuResource = Resource.newBuilder() + .setName("cpus").setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(driverCpu)).build() + val memResource = Resource.newBuilder() + .setName("mem").setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(driverMem)).build() + val commandInfo = buildDriverCommand(submission) + val appName = submission.schedulerProperties("spark.app.name") + val taskInfo = TaskInfo.newBuilder() + .setTaskId(taskId) + .setName(s"Driver for $appName") + .setSlaveId(offer.offer.getSlaveId) + .setCommand(commandInfo) + .addResources(cpuResource) + .addResources(memResource) + .build() + val queuedTasks = tasks.getOrElseUpdate(offer.offer.getId, new ArrayBuffer[TaskInfo]) + queuedTasks += taskInfo + logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " + + submission.submissionId) + val newState = new MesosClusterSubmissionState(submission, taskId, offer.offer.getSlaveId, + None, new Date()) + launchedDrivers(submission.submissionId) = newState + launchedDriversState.persist(submission.submissionId, newState) + afterLaunchCallback(submission.submissionId) + } + } + } + + override def resourceOffers(driver: SchedulerDriver, offers: JList[Offer]): Unit = { + val currentOffers = offers.map { o => + new ResourceOffer( + o, getResource(o.getResourcesList, "cpus"), getResource(o.getResourcesList, "mem")) + }.toList + logTrace(s"Received offers from Mesos: \n${currentOffers.mkString("\n")}") + val tasks = new mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]() + val currentTime = new Date() + + stateLock.synchronized { + // We first schedule all the supervised drivers that are ready to retry. + // This list will be empty if none of the drivers are marked as supervise. + val driversToRetry = pendingRetryDrivers.filter { d => + d.retryState.get.nextRetry.before(currentTime) + } + scheduleTasks( + driversToRetry, + removeFromPendingRetryDrivers, + currentOffers, + tasks) + // Then we walk through the queued drivers and try to schedule them. + scheduleTasks( + queuedDrivers, + removeFromQueuedDrivers, + currentOffers, + tasks) + } + tasks.foreach { case (offerId, tasks) => + driver.launchTasks(Collections.singleton(offerId), tasks) + } + offers + .filter(o => !tasks.keySet.contains(o.getId)) + .foreach(o => driver.declineOffer(o.getId)) + } + + def getSchedulerState(): MesosClusterSchedulerState = { + def copyBuffer( + buffer: ArrayBuffer[MesosDriverDescription]): ArrayBuffer[MesosDriverDescription] = { + val newBuffer = new ArrayBuffer[MesosDriverDescription](buffer.size) + buffer.copyToBuffer(newBuffer) + newBuffer + } + stateLock.synchronized { + new MesosClusterSchedulerState( + frameworkId, + masterInfo.map(m => s"http://${m.getIp}:${m.getPort}"), + copyBuffer(queuedDrivers), + launchedDrivers.values.map(_.copy()).toList, + finishedDrivers.map(_.copy()).toList, + copyBuffer(pendingRetryDrivers)) + } + } + + override def offerRescinded(driver: SchedulerDriver, offerId: OfferID): Unit = {} + override def disconnected(driver: SchedulerDriver): Unit = {} + override def reregistered(driver: SchedulerDriver, masterInfo: MasterInfo): Unit = { + logInfo(s"Framework re-registered with master ${masterInfo.getId}") + } + override def slaveLost(driver: SchedulerDriver, slaveId: SlaveID): Unit = {} + override def error(driver: SchedulerDriver, error: String): Unit = { + logError("Error received: " + error) + } + + /** + * Check if the task state is a recoverable state that we can relaunch the task. + * Task state like TASK_ERROR are not relaunchable state since it wasn't able + * to be validated by Mesos. + */ + private def shouldRelaunch(state: MesosTaskState): Boolean = { + state == MesosTaskState.TASK_FAILED || + state == MesosTaskState.TASK_KILLED || + state == MesosTaskState.TASK_LOST + } + + override def statusUpdate(driver: SchedulerDriver, status: TaskStatus): Unit = { + val taskId = status.getTaskId.getValue + stateLock.synchronized { + if (launchedDrivers.contains(taskId)) { + if (status.getReason == Reason.REASON_RECONCILIATION && + !pendingRecover.contains(taskId)) { + // Task has already received update and no longer requires reconciliation. + return + } + val state = launchedDrivers(taskId) + // Check if the driver is supervise enabled and can be relaunched. + if (state.driverDescription.supervise && shouldRelaunch(status.getState)) { + removeFromLaunchedDrivers(taskId) + val retryState: Option[MesosClusterRetryState] = state.driverDescription.retryState + val (retries, waitTimeSec) = retryState + .map { rs => (rs.retries + 1, Math.min(maxRetryWaitTime, rs.waitTime * 2)) } + .getOrElse{ (1, 1) } + val nextRetry = new Date(new Date().getTime + waitTimeSec * 1000L) + + val newDriverDescription = state.driverDescription.copy( + retryState = Some(new MesosClusterRetryState(status, retries, nextRetry, waitTimeSec))) + pendingRetryDrivers += newDriverDescription + pendingRetryDriversState.persist(taskId, newDriverDescription) + } else if (TaskState.isFinished(TaskState.fromMesos(status.getState))) { + removeFromLaunchedDrivers(taskId) + if (finishedDrivers.size >= retainedDrivers) { + val toRemove = math.max(retainedDrivers / 10, 1) + finishedDrivers.trimStart(toRemove) + } + finishedDrivers += state + } + state.mesosTaskStatus = Option(status) + } else { + logError(s"Unable to find driver $taskId in status update") + } + } + } + + override def frameworkMessage( + driver: SchedulerDriver, + executorId: ExecutorID, + slaveId: SlaveID, + message: Array[Byte]): Unit = {} + + override def executorLost( + driver: SchedulerDriver, + executorId: ExecutorID, + slaveId: SlaveID, + status: Int): Unit = {} + + private def removeFromQueuedDrivers(id: String): Boolean = { + val index = queuedDrivers.indexWhere(_.submissionId.equals(id)) + if (index != -1) { + queuedDrivers.remove(index) + queuedDriversState.expunge(id) + true + } else { + false + } + } + + private def removeFromLaunchedDrivers(id: String): Boolean = { + if (launchedDrivers.remove(id).isDefined) { + launchedDriversState.expunge(id) + true + } else { + false + } + } + + private def removeFromPendingRetryDrivers(id: String): Boolean = { + val index = pendingRetryDrivers.indexWhere(_.submissionId.equals(id)) + if (index != -1) { + pendingRetryDrivers.remove(index) + pendingRetryDriversState.expunge(id) + true + } else { + false + } + } + + def getQueuedDriversSize: Int = queuedDrivers.size + def getLaunchedDriversSize: Int = launchedDrivers.size + def getPendingRetryDriversSize: Int = pendingRetryDrivers.size +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala new file mode 100644 index 0000000000000..1fe94974c8e36 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.metrics.source.Source + +private[mesos] class MesosClusterSchedulerSource(scheduler: MesosClusterScheduler) + extends Source { + override def sourceName: String = "mesos_cluster" + override def metricRegistry: MetricRegistry = new MetricRegistry() + + metricRegistry.register(MetricRegistry.name("waitingDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getQueuedDriversSize + }) + + metricRegistry.register(MetricRegistry.name("launchedDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getLaunchedDriversSize + }) + + metricRegistry.register(MetricRegistry.name("retryDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getPendingRetryDriversSize + }) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index d9d62b0e287ed..8346a2407489f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -18,23 +18,19 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{ArrayList => JArrayList, List => JList} -import java.util.Collections +import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.mesos.protobuf.ByteString -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, - ExecutorInfo => MesosExecutorInfo, _} - +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.spark.executor.MesosExecutorBackend -import org.apache.spark.{Logging, SparkContext, SparkException, TaskState} -import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils +import org.apache.spark.{SparkContext, SparkException, TaskState} /** * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a @@ -47,14 +43,7 @@ private[spark] class MesosSchedulerBackend( master: String) extends SchedulerBackend with MScheduler - with Logging { - - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() - - // Driver for talking to Mesos - var driver: SchedulerDriver = null + with MesosSchedulerUtils { // Which slave IDs we have executors on val slaveIdsWithExecutors = new HashSet[String] @@ -73,26 +62,9 @@ private[spark] class MesosSchedulerBackend( @volatile var appId: String = _ override def start() { - synchronized { - classLoader = Thread.currentThread.getContextClassLoader - - new Thread("MesosSchedulerBackend driver") { - setDaemon(true) - override def run() { - val scheduler = MesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() - driver = new MesosSchedulerDriver(scheduler, fwInfo, master) - try { - val ret = driver.run() - logInfo("driver.run() returned with code " + ret) - } catch { - case e: Exception => logError("driver.run() failed", e) - } - } - }.start() - - waitForRegister() - } + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() + classLoader = Thread.currentThread.getContextClassLoader + startScheduler(master, MesosSchedulerBackend.this, fwInfo) } def createExecutorInfo(execId: String): MesosExecutorInfo = { @@ -125,17 +97,19 @@ private[spark] class MesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val uri = sc.conf.get("spark.executor.uri", null) + val uri = sc.conf.getOption("spark.executor.uri") + .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) + val executorBackendName = classOf[MesosExecutorBackend].getName - if (uri == null) { + if (uri.isEmpty) { val executorPath = new File(executorSparkHome, "/bin/spark-class").getCanonicalPath command.setValue(s"$prefixEnv $executorPath $executorBackendName") } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". - val basename = uri.split('/').last.split('.').head + val basename = uri.get.split('/').last.split('.').head command.setValue(s"cd ${basename}*; $prefixEnv ./bin/spark-class $executorBackendName") - command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) + command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } val cpus = Resource.newBuilder() .setName("cpus") @@ -181,18 +155,7 @@ private[spark] class MesosSchedulerBackend( inClassLoader() { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) - registeredLock.synchronized { - isRegistered = true - registeredLock.notifyAll() - } - } - } - - def waitForRegister() { - registeredLock.synchronized { - while (!isRegistered) { - registeredLock.wait() - } + markRegistered() } } @@ -287,14 +250,6 @@ private[spark] class MesosSchedulerBackend( } } - /** Helper function to pull out a resource from a Mesos Resources protobuf */ - def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue - } - 0 - } - /** Turn a Spark TaskDescription into a Mesos task */ def createMesosTask(task: TaskDescription, slaveId: String): MesosTaskInfo = { val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build() @@ -339,13 +294,13 @@ private[spark] class MesosSchedulerBackend( } override def stop() { - if (driver != null) { - driver.stop() + if (mesosDriver != null) { + mesosDriver.stop() } } override def reviveOffers() { - driver.reviveOffers() + mesosDriver.reviveOffers() } override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} @@ -380,7 +335,7 @@ private[spark] class MesosSchedulerBackend( } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { - driver.killTask( + mesosDriver.killTask( TaskID.newBuilder() .setValue(taskId.toString).build() ) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala new file mode 100644 index 0000000000000..d11228f3d016a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.util.List +import java.util.concurrent.CountDownLatch + +import scala.collection.JavaConversions._ + +import org.apache.mesos.Protos.{FrameworkInfo, Resource, Status} +import org.apache.mesos.{MesosSchedulerDriver, Scheduler} +import org.apache.spark.Logging +import org.apache.spark.util.Utils + +/** + * Shared trait for implementing a Mesos Scheduler. This holds common state and helper + * methods and Mesos scheduler will use. + */ +private[mesos] trait MesosSchedulerUtils extends Logging { + // Lock used to wait for scheduler to be registered + private final val registerLatch = new CountDownLatch(1) + + // Driver for talking to Mesos + protected var mesosDriver: MesosSchedulerDriver = null + + /** + * Starts the MesosSchedulerDriver with the provided information. This method returns + * only after the scheduler has registered with Mesos. + * @param masterUrl Mesos master connection URL + * @param scheduler Scheduler object + * @param fwInfo FrameworkInfo to pass to the Mesos master + */ + def startScheduler(masterUrl: String, scheduler: Scheduler, fwInfo: FrameworkInfo): Unit = { + synchronized { + if (mesosDriver != null) { + registerLatch.await() + return + } + + new Thread(Utils.getFormattedClassName(this) + "-mesos-driver") { + setDaemon(true) + + override def run() { + mesosDriver = new MesosSchedulerDriver(scheduler, fwInfo, masterUrl) + try { + val ret = mesosDriver.run() + logInfo("driver.run() returned with code " + ret) + if (ret.equals(Status.DRIVER_ABORTED)) { + System.exit(1) + } + } catch { + case e: Exception => { + logError("driver.run() failed", e) + System.exit(1) + } + } + } + }.start() + + registerLatch.await() + } + } + + /** + * Signal that the scheduler has registered with Mesos. + */ + protected def markRegistered(): Unit = { + registerLatch.countDown() + } + + /** + * Get the amount of resources for the specified type from the resource list + */ + protected def getResource(res: List[Resource], name: String): Double = { + for (r <- res if r.getName == name) { + return r.getScalar.getValue + } + 0.0 + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index ac5b524517818..e64d06c4d3cfc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -123,7 +123,7 @@ private[spark] class LocalBackend( } override def stop() { - localEndpoint.sendWithReply(StopExecutor) + localEndpoint.ask(StopExecutor) } override def reviveOffers() { diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 579fb6624e692..754832b8a4ca7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -49,16 +49,17 @@ class KryoSerializer(conf: SparkConf) with Logging with Serializable { - private val bufferSizeMb = conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) - if (bufferSizeMb >= 2048) { - throw new IllegalArgumentException("spark.kryoserializer.buffer.mb must be less than " + - s"2048 mb, got: + $bufferSizeMb mb.") + private val bufferSizeKb = conf.getSizeAsKb("spark.kryoserializer.buffer", "64k") + + if (bufferSizeKb >= 2048) { + throw new IllegalArgumentException("spark.kryoserializer.buffer must be less than " + + s"2048 mb, got: + $bufferSizeKb mb.") } - private val bufferSize = (bufferSizeMb * 1024 * 1024).toInt + private val bufferSize = (bufferSizeKb * 1024).toInt - val maxBufferSizeMb = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) + val maxBufferSizeMb = conf.getSizeAsMb("spark.kryoserializer.buffer.max", "64m").toInt if (maxBufferSizeMb >= 2048) { - throw new IllegalArgumentException("spark.kryoserializer.buffer.max.mb must be less than " + + throw new IllegalArgumentException("spark.kryoserializer.buffer.max must be less than " + s"2048 mb, got: + $maxBufferSizeMb mb.") } private val maxBufferSize = maxBufferSizeMb * 1024 * 1024 @@ -173,7 +174,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ } catch { case e: KryoException if e.getMessage.startsWith("Buffer overflow") => throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + - "increase spark.kryoserializer.buffer.max.mb value.") + "increase spark.kryoserializer.buffer.max value.") } ByteBuffer.wrap(output.toBytes) } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index cecb992579655..5abfa467c0ec8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -23,6 +23,7 @@ import java.security.AccessController import scala.annotation.tailrec import scala.collection.mutable +import scala.util.control.NonFatal import org.apache.spark.Logging @@ -35,8 +36,15 @@ private[serializer] object SerializationDebugger extends Logging { */ def improveException(obj: Any, e: NotSerializableException): NotSerializableException = { if (enableDebugging && reflect != null) { - new NotSerializableException( - e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + try { + new NotSerializableException( + e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + } catch { + case NonFatal(t) => + // Fall back to old exception + logWarning("Exception in serialization debugger", t) + e + } } else { e } diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 538e150ead05a..e9b4e2b955dc8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -78,7 +78,8 @@ class FileShuffleBlockManager(conf: SparkConf) private val consolidateShuffleFiles = conf.getBoolean("spark.shuffle.consolidateFiles", false) - private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** * Contains all the state related to a particular shuffle. This includes a pool of unused diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 7a2c5ae32d98b..80374adc44296 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -79,7 +79,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { blockManager, blocksByAddress, serializer, - SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024) + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index c798843bd5d8a..9bfc4201d37c0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -55,7 +55,7 @@ class BlockManagerMaster( memSize: Long, diskSize: Long, tachyonSize: Long): Boolean = { - val res = driverEndpoint.askWithReply[Boolean]( + val res = driverEndpoint.askWithRetry[Boolean]( UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize)) logDebug(s"Updated info of block $blockId") res @@ -63,12 +63,12 @@ class BlockManagerMaster( /** Get locations of the blockId from the driver */ def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - driverEndpoint.askWithReply[Seq[BlockManagerId]](GetLocations(blockId)) + driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetLocations(blockId)) } /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - driverEndpoint.askWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + driverEndpoint.askWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } /** @@ -81,11 +81,11 @@ class BlockManagerMaster( /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { - driverEndpoint.askWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) + driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId)) } def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { - driverEndpoint.askWithReply[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) + driverEndpoint.askWithRetry[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) } /** @@ -93,12 +93,12 @@ class BlockManagerMaster( * blocks that the driver knows about. */ def removeBlock(blockId: BlockId) { - driverEndpoint.askWithReply[Boolean](RemoveBlock(blockId)) + driverEndpoint.askWithRetry[Boolean](RemoveBlock(blockId)) } /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { - val future = driverEndpoint.askWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) + val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}") @@ -110,7 +110,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given shuffle. */ def removeShuffle(shuffleId: Int, blocking: Boolean) { - val future = driverEndpoint.askWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) + val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}") @@ -122,7 +122,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { - val future = driverEndpoint.askWithReply[Future[Seq[Int]]]( + val future = driverEndpoint.askWithRetry[Future[Seq[Int]]]( RemoveBroadcast(broadcastId, removeFromMaster)) future.onFailure { case e: Exception => @@ -141,11 +141,11 @@ class BlockManagerMaster( * amount of remaining memory. */ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - driverEndpoint.askWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + driverEndpoint.askWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } def getStorageStatus: Array[StorageStatus] = { - driverEndpoint.askWithReply[Array[StorageStatus]](GetStorageStatus) + driverEndpoint.askWithRetry[Array[StorageStatus]](GetStorageStatus) } /** @@ -166,7 +166,7 @@ class BlockManagerMaster( * master endpoint for a response to a prior message. */ val response = driverEndpoint. - askWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + askWithRetry[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip val result = Await.result(Future.sequence(futures), timeout) if (result == null) { @@ -190,7 +190,7 @@ class BlockManagerMaster( filter: BlockId => Boolean, askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) - val future = driverEndpoint.askWithReply[Future[Seq[BlockId]]](msg) + val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg) Await.result(future, timeout) } @@ -205,7 +205,7 @@ class BlockManagerMaster( /** Send a one-way message to the master endpoint, to which we expect it to reply with true. */ private def tell(message: Any) { - if (!driverEndpoint.askWithReply[Boolean](message)) { + if (!driverEndpoint.askWithRetry[Boolean](message)) { throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.") } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 4682167912ff0..7212362df5d71 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -132,7 +132,7 @@ class BlockManagerMasterEndpoint( val removeMsg = RemoveRdd(rddId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveEndpoint.sendWithReply[Int](removeMsg) + bm.slaveEndpoint.ask[Int](removeMsg) }.toSeq ) } @@ -142,7 +142,7 @@ class BlockManagerMasterEndpoint( val removeMsg = RemoveShuffle(shuffleId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveEndpoint.sendWithReply[Boolean](removeMsg) + bm.slaveEndpoint.ask[Boolean](removeMsg) }.toSeq ) } @@ -159,7 +159,7 @@ class BlockManagerMasterEndpoint( } Future.sequence( requiredBlockManagers.map { bm => - bm.slaveEndpoint.sendWithReply[Int](removeMsg) + bm.slaveEndpoint.ask[Int](removeMsg) }.toSeq ) } @@ -214,7 +214,7 @@ class BlockManagerMasterEndpoint( // Remove the block from the slave's BlockManager. // Doesn't actually wait for a confirmation and the message might get lost. // If message loss becomes frequent, we should add retry logic here. - blockManager.get.slaveEndpoint.sendWithReply[Boolean](RemoveBlock(blockId)) + blockManager.get.slaveEndpoint.ask[Boolean](RemoveBlock(blockId)) } } } @@ -253,7 +253,7 @@ class BlockManagerMasterEndpoint( blockManagerInfo.values.map { info => val blockStatusFuture = if (askSlaves) { - info.slaveEndpoint.sendWithReply[Option[BlockStatus]](getBlockStatus) + info.slaveEndpoint.ask[Option[BlockStatus]](getBlockStatus) } else { Future { info.getStatus(blockId) } } @@ -277,7 +277,7 @@ class BlockManagerMasterEndpoint( blockManagerInfo.values.map { info => val future = if (askSlaves) { - info.slaveEndpoint.sendWithReply[Seq[BlockId]](getMatchingBlockIds) + info.slaveEndpoint.ask[Seq[BlockId]](getMatchingBlockIds) } else { Future { info.blocks.keys.filter(filter).toSeq } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 4b232ae7d3180..1f45956282166 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -31,8 +31,7 @@ import org.apache.spark.util.Utils private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager) extends BlockStore(blockManager) with Logging { - val minMemoryMapBytes = blockManager.conf.getLong( - "spark.storage.memoryMapThreshold", 2 * 1024L * 1024L) + val minMemoryMapBytes = blockManager.conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") override def getSize(blockId: BlockId): Long = { diskManager.getFile(blockId.name).length diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 711a3697bda15..935c8a4f80e7b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -24,7 +24,7 @@ import org.apache.spark.util.collection.OpenHashSet import scala.collection.mutable.HashMap -private[jobs] object UIData { +private[spark] object UIData { class ExecutorSummary { var taskTime : Long = 0 diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 342bc9a06db47..4b5a5df5ef7b7 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1020,21 +1020,48 @@ private[spark] object Utils extends Logging { } /** - * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in bytes. + */ + def byteStringAsBytes(str: String): Long = { + JavaUtils.byteStringAsBytes(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in kibibytes. + */ + def byteStringAsKb(str: String): Long = { + JavaUtils.byteStringAsKb(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in mebibytes. + */ + def byteStringAsMb(str: String): Long = { + JavaUtils.byteStringAsMb(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m, 500g) to gibibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in gibibytes. + */ + def byteStringAsGb(str: String): Long = { + JavaUtils.byteStringAsGb(str) + } + + /** + * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of mebibytes. */ def memoryStringToMb(str: String): Int = { - val lower = str.toLowerCase - if (lower.endsWith("k")) { - (lower.substring(0, lower.length-1).toLong / 1024).toInt - } else if (lower.endsWith("m")) { - lower.substring(0, lower.length-1).toInt - } else if (lower.endsWith("g")) { - lower.substring(0, lower.length-1).toInt * 1024 - } else if (lower.endsWith("t")) { - lower.substring(0, lower.length-1).toInt * 1024 * 1024 - } else {// no suffix, so it's just a number in bytes - (lower.toLong / 1024 / 1024).toInt - } + // Convert to bytes, rather than directly to MB, because when no units are specified the unit + // is assumed to be bytes + (JavaUtils.byteStringAsBytes(str) / 1024 / 1024).toInt } /** @@ -1272,16 +1299,18 @@ private[spark] object Utils extends Logging { } /** Default filtering function for finding call sites using `getCallSite`. */ - private def coreExclusionFunction(className: String): Boolean = { - // A regular expression to match classes of the "core" Spark API that we want to skip when - // finding the call site of a method. + private def sparkInternalExclusionFunction(className: String): Boolean = { + // A regular expression to match classes of the internal Spark API's + // that we want to skip when finding the call site of a method. val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r + val SPARK_SQL_CLASS_REGEX = """^org\.apache\.spark\.sql.*""".r val SCALA_CORE_CLASS_PREFIX = "scala" - val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined + val isSparkClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined || + SPARK_SQL_CLASS_REGEX.findFirstIn(className).isDefined val isScalaClass = className.startsWith(SCALA_CORE_CLASS_PREFIX) // If the class is a Spark internal class or a Scala class, then exclude. - isSparkCoreClass || isScalaClass + isSparkClass || isScalaClass } /** @@ -1291,7 +1320,7 @@ private[spark] object Utils extends Logging { * * @param skipClass Function that is used to exclude non-user-code classes. */ - def getCallSite(skipClass: String => Boolean = coreExclusionFunction): CallSite = { + def getCallSite(skipClass: String => Boolean = sparkInternalExclusionFunction): CallSite = { // Keep crawling up the stack trace until we find the first function not inside of the spark // package. We track the last (shallowest) contiguous Spark method. This might be an RDD // transformation, a SparkContext function (such as parallelize), or anything else that leads @@ -1330,9 +1359,17 @@ private[spark] object Utils extends Logging { } val callStackDepth = System.getProperty("spark.callstack.depth", "20").toInt - CallSite( - shortForm = s"$lastSparkMethod at $firstUserFile:$firstUserLine", - longForm = callStack.take(callStackDepth).mkString("\n")) + val shortForm = + if (firstUserFile == "HiveSessionImpl.java") { + // To be more user friendly, show a nicer string for queries submitted from the JDBC + // server. + "Spark JDBC Server Query" + } else { + s"$lastSparkMethod at $firstUserFile:$firstUserLine" + } + val longForm = callStack.take(callStackDepth).mkString("\n") + + CallSite(shortForm, longForm) } /** Return a string containing part of a file from byte 'start' to 'end'. */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 30dd7f22e494f..f912049563906 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -89,8 +89,10 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L - - private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val fileBufferSize = + sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 // Write metrics for current spill private var curWriteMetrics: ShuffleWriteMetrics = _ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 79a695fb62086..4ed8a740f99db 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -108,7 +108,9 @@ private[spark] class ExternalSorter[K, V, C]( private val conf = SparkEnv.get.conf private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true) // Size of object batches when reading/writing from serializers. @@ -525,7 +527,8 @@ private[spark] class ExternalSorter[K, V, C]( val k = elem._1 var c = elem._2 while (sorted.hasNext && sorted.head._1 == k) { - c = mergeCombiners(c, sorted.head._2) + val pair = sorted.next() + c = mergeCombiners(c, pair._2) } (k, c) } diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index e579421676343..7138b4b8e4533 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -138,7 +138,7 @@ private[spark] object RollingFileAppender { val STRATEGY_DEFAULT = "" val INTERVAL_PROPERTY = "spark.executor.logs.rolling.time.interval" val INTERVAL_DEFAULT = "daily" - val SIZE_PROPERTY = "spark.executor.logs.rolling.size.maxBytes" + val SIZE_PROPERTY = "spark.executor.logs.rolling.maxSize" val SIZE_DEFAULT = (1024 * 1024).toString val RETAINED_FILES_PROPERTY = "spark.executor.logs.rolling.maxRetainedFiles" val DEFAULT_BUFFER_SIZE = 8192 diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 8a4f2a08fe701..c2089b0e56a1f 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -157,11 +157,11 @@ public void sample() { public void randomSplit() { List ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); JavaRDD rdd = sc.parallelize(ints); - JavaRDD[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 11); + JavaRDD[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 31); Assert.assertEquals(3, splits.length); - Assert.assertEquals(2, splits[0].count()); - Assert.assertEquals(3, splits[1].count()); - Assert.assertEquals(5, splits[2].count()); + Assert.assertEquals(1, splits[0].count()); + Assert.assertEquals(2, splits[1].count()); + Assert.assertEquals(7, splits[2].count()); } @Test @@ -1009,7 +1009,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics()); + TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 70529d9216591..668ddf9f5f0a9 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -65,7 +65,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf // in blockManager.put is a losing battle. You have been warned. blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0, null) val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -77,7 +77,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - val context = new TaskContextImpl(0, 0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0, null) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -86,14 +86,14 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf // Local computation should not persist the resulting value, so don't expect a put(). when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - val context = new TaskContextImpl(0, 0, 0, 0, true) + val context = new TaskContextImpl(0, 0, 0, 0, null, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0, null) cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 97ea3578aa8ba..96a9c207ad022 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -77,7 +77,7 @@ class DistributedSuite extends FunSuite with Matchers with LocalSparkContext { } test("groupByKey where map output sizes exceed maxMbInFlight") { - val conf = new SparkConf().set("spark.reducer.maxMbInFlight", "1") + val conf = new SparkConf().set("spark.reducer.maxSizeInFlight", "1m") sc = new SparkContext(clusterUrl, "test", conf) // This data should be around 20 MB, so even with 4 mappers and 2 reducers, each map output // file should be about 2.5 MB diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 0fd570e5297d9..b789912e9ebef 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -48,7 +48,7 @@ class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { val metrics = new TaskMetrics val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) - val response = receiverRef.askWithReply[HeartbeatResponse]( + val response = receiverRef.askWithRetry[HeartbeatResponse]( Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) verify(scheduler).executorHeartbeatReceived( @@ -71,7 +71,7 @@ class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { val metrics = new TaskMetrics val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) - val response = receiverRef.askWithReply[HeartbeatResponse]( + val response = receiverRef.askWithRetry[HeartbeatResponse]( Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) verify(scheduler).executorHeartbeatReceived( diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 4d3e09793faff..ae17fc60e4a43 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -141,6 +141,41 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter assert(jobB.get() === 100) } + test("inherited job group (SPARK-6629)") { + sc = new SparkContext("local[2]", "test") + + // Add a listener to release the semaphore once any tasks are launched. + val sem = new Semaphore(0) + sc.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem.release() + } + }) + + sc.setJobGroup("jobA", "this is a job to be cancelled") + @volatile var exception: Exception = null + val jobA = new Thread() { + // The job group should be inherited by this thread + override def run(): Unit = { + exception = intercept[SparkException] { + sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count() + } + } + } + jobA.start() + + // Block until both tasks of job A have started and cancel job A. + sem.acquire(2) + sc.cancelJobGroup("jobA") + jobA.join(10000) + assert(!jobA.isAlive) + assert(exception.getMessage contains "cancel") + + // Once A is cancelled, job B should finish fairly quickly. + val jobB = sc.parallelize(1 to 100, 2).countAsync() + assert(jobB.get() === 100) + } + test("job group with interruption") { sc = new SparkContext("local[2]", "test") diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 272e6af0514e4..68d08e32f9aa4 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -24,11 +24,30 @@ import scala.language.postfixOps import scala.util.{Try, Random} import org.scalatest.FunSuite +import org.apache.spark.network.util.ByteUnit import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} import org.apache.spark.util.{RpcUtils, ResetSystemProperties} import com.esotericsoftware.kryo.Kryo class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { + test("Test byteString conversion") { + val conf = new SparkConf() + // Simply exercise the API, we don't need a complete conversion test since that's handled in + // UtilsSuite.scala + assert(conf.getSizeAsBytes("fake","1k") === ByteUnit.KiB.toBytes(1)) + assert(conf.getSizeAsKb("fake","1k") === ByteUnit.KiB.toKiB(1)) + assert(conf.getSizeAsMb("fake","1k") === ByteUnit.KiB.toMiB(1)) + assert(conf.getSizeAsGb("fake","1k") === ByteUnit.KiB.toGiB(1)) + } + + test("Test timeString conversion") { + val conf = new SparkConf() + // Simply exercise the API, we don't need a complete conversion test since that's handled in + // UtilsSuite.scala + assert(conf.getTimeAsMs("fake","1ms") === TimeUnit.MILLISECONDS.toMillis(1)) + assert(conf.getTimeAsSeconds("fake","1000ms") === TimeUnit.MILLISECONDS.toSeconds(1000)) + } + test("loading from system properties") { System.setProperty("spark.test.testProperty", "2") val conf = new SparkConf() diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 728558a424780..9049db7755358 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -25,7 +25,9 @@ import com.google.common.io.Files import org.scalatest.FunSuite -import org.apache.hadoop.io.BytesWritable +import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} +import org.apache.hadoop.mapred.TextInputFormat +import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.apache.spark.util.Utils import scala.concurrent.Await @@ -213,4 +215,63 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { sc.stop() } } + + test("Comma separated paths for newAPIHadoopFile/wholeTextFiles/binaryFiles (SPARK-7155)") { + // Regression test for SPARK-7155 + // dir1 and dir2 are used for wholeTextFiles and binaryFiles + val dir1 = Utils.createTempDir() + val dir2 = Utils.createTempDir() + + val dirpath1=dir1.getAbsolutePath + val dirpath2=dir2.getAbsolutePath + + // file1 and file2 are placed inside dir1, they are also used for + // textFile, hadoopFile, and newAPIHadoopFile + // file3, file4 and file5 are placed inside dir2, they are used for + // textFile, hadoopFile, and newAPIHadoopFile as well + val file1 = new File(dir1, "part-00000") + val file2 = new File(dir1, "part-00001") + val file3 = new File(dir2, "part-00000") + val file4 = new File(dir2, "part-00001") + val file5 = new File(dir2, "part-00002") + + val filepath1=file1.getAbsolutePath + val filepath2=file2.getAbsolutePath + val filepath3=file3.getAbsolutePath + val filepath4=file4.getAbsolutePath + val filepath5=file5.getAbsolutePath + + + try { + // Create 5 text files. + Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", file1, UTF_8) + Files.write("someline1 in file2\nsomeline2 in file2", file2, UTF_8) + Files.write("someline1 in file3", file3, UTF_8) + Files.write("someline1 in file4\nsomeline2 in file4", file4, UTF_8) + Files.write("someline1 in file2\nsomeline2 in file5", file5, UTF_8) + + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + + // Test textFile, hadoopFile, and newAPIHadoopFile for file1 and file2 + assert(sc.textFile(filepath1 + "," + filepath2).count() == 5L) + assert(sc.hadoopFile(filepath1 + "," + filepath2, + classOf[TextInputFormat], classOf[LongWritable], classOf[Text]).count() == 5L) + assert(sc.newAPIHadoopFile(filepath1 + "," + filepath2, + classOf[NewTextInputFormat], classOf[LongWritable], classOf[Text]).count() == 5L) + + // Test textFile, hadoopFile, and newAPIHadoopFile for file3, file4, and file5 + assert(sc.textFile(filepath3 + "," + filepath4 + "," + filepath5).count() == 5L) + assert(sc.hadoopFile(filepath3 + "," + filepath4 + "," + filepath5, + classOf[TextInputFormat], classOf[LongWritable], classOf[Text]).count() == 5L) + assert(sc.newAPIHadoopFile(filepath3 + "," + filepath4 + "," + filepath5, + classOf[NewTextInputFormat], classOf[LongWritable], classOf[Text]).count() == 5L) + + // Test wholeTextFiles, and binaryFiles for dir1 and dir2 + assert(sc.wholeTextFiles(dirpath1 + "," + dirpath2).count() == 5L) + assert(sc.binaryFiles(dirpath1 + "," + dirpath2).count() == 5L) + + } finally { + sc.stop() + } + } } diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index b5383d553add1..10917c866cc7d 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark -import java.util.concurrent.Semaphore +import java.util.concurrent.{TimeUnit, Semaphore} import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicInteger +import org.apache.spark.scheduler._ import org.scalatest.FunSuite /** @@ -189,4 +190,47 @@ class ThreadingSuite extends FunSuite with LocalSparkContext { assert(sc.getLocalProperty("test") === "parent") assert(sc.getLocalProperty("Foo") === null) } + + test("mutations to local properties should not affect submitted jobs (SPARK-6629)") { + val jobStarted = new Semaphore(0) + val jobEnded = new Semaphore(0) + @volatile var jobResult: JobResult = null + + sc = new SparkContext("local", "test") + sc.setJobGroup("originalJobGroupId", "description") + sc.addSparkListener(new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobStarted.release() + } + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + jobResult = jobEnd.jobResult + jobEnded.release() + } + }) + + // Create a new thread which will inherit the current thread's properties + val thread = new Thread() { + override def run(): Unit = { + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId") + // Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task + try { + sc.parallelize(1 to 100).foreach { x => + Thread.sleep(100) + } + } catch { + case s: SparkException => // ignored so that we don't print noise in test logs + } + } + } + thread.start() + // Wait for the job to start, then mutate the original properties, which should have been + // inherited by the running job but hopefully defensively copied or snapshotted: + jobStarted.tryAcquire(10, TimeUnit.SECONDS) + sc.setJobGroup("modifiedJobGroupId", "description") + // Canceling the original job group should cancel the running job. In other words, the + // modification of the properties object should not affect the properties of running jobs + sc.cancelJobGroup("originalJobGroupId") + jobEnded.tryAcquire(10, TimeUnit.SECONDS) + assert(jobResult.isInstanceOf[JobFailed]) + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala new file mode 100644 index 0000000000000..529f91e8eaf9e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io.{File, FileInputStream, FileOutputStream} +import java.nio.file.{Files, Path} +import java.util.jar.{JarEntry, JarOutputStream} + +import org.apache.spark.TestUtils.{createCompiledClass, JavaSourceFromString} + +import com.google.common.io.ByteStreams + +import org.apache.commons.io.FileUtils + +import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate + +private[deploy] object IvyTestUtils { + + /** + * Create the path for the jar and pom from the maven coordinate. Extension should be `jar` + * or `pom`. + */ + private def pathFromCoordinate( + artifact: MavenCoordinate, + prefix: Path, + ext: String, + useIvyLayout: Boolean): Path = { + val groupDirs = artifact.groupId.replace(".", File.separator) + val artifactDirs = artifact.artifactId + val artifactPath = + if (!useIvyLayout) { + Seq(groupDirs, artifactDirs, artifact.version).mkString(File.separator) + } else { + Seq(groupDirs, artifactDirs, artifact.version, ext + "s").mkString(File.separator) + } + new File(prefix.toFile, artifactPath).toPath + } + + private def artifactName(artifact: MavenCoordinate, ext: String = ".jar"): String = { + s"${artifact.artifactId}-${artifact.version}$ext" + } + + /** Write the contents to a file to the supplied directory. */ + private def writeFile(dir: File, fileName: String, contents: String): File = { + val outputFile = new File(dir, fileName) + val outputStream = new FileOutputStream(outputFile) + outputStream.write(contents.toCharArray.map(_.toByte)) + outputStream.close() + outputFile + } + + /** Create an example Python file. */ + private def createPythonFile(dir: File): File = { + val contents = + """def myfunc(x): + | return x + 1 + """.stripMargin + writeFile(dir, "mylib.py", contents) + } + + /** Create a simple testable Class. */ + private def createJavaClass(dir: File, className: String, packageName: String): File = { + val contents = + s"""package $packageName; + | + |import java.lang.Integer; + | + |class $className implements java.io.Serializable { + | + | public $className() {} + | + | public Integer myFunc(Integer x) { + | return x + 1; + | } + |} + """.stripMargin + val sourceFile = + new JavaSourceFromString(new File(dir, className + ".java").getAbsolutePath, contents) + createCompiledClass(className, dir, sourceFile, Seq.empty) + } + + /** Helper method to write artifact information in the pom. */ + private def pomArtifactWriter(artifact: MavenCoordinate, tabCount: Int = 1): String = { + var result = "\n" + " " * tabCount + s"${artifact.groupId}" + result += "\n" + " " * tabCount + s"${artifact.artifactId}" + result += "\n" + " " * tabCount + s"${artifact.version}" + result + } + + /** Create a pom file for this artifact. */ + private def createPom( + dir: File, + artifact: MavenCoordinate, + dependencies: Option[Seq[MavenCoordinate]]): File = { + var content = """ + | + | + | 4.0.0 + """.stripMargin.trim + content += pomArtifactWriter(artifact) + content += dependencies.map { deps => + val inside = deps.map { dep => + "\t" + pomArtifactWriter(dep, 3) + "\n\t" + }.mkString("\n") + "\n \n" + inside + "\n " + }.getOrElse("") + content += "\n" + writeFile(dir, artifactName(artifact, ".pom"), content.trim) + } + + /** Create the jar for the given maven coordinate, using the supplied files. */ + private def packJar( + dir: File, + artifact: MavenCoordinate, + files: Seq[(String, File)]): File = { + val jarFile = new File(dir, artifactName(artifact)) + val jarFileStream = new FileOutputStream(jarFile) + val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) + + for (file <- files) { + val jarEntry = new JarEntry(file._1) + jarStream.putNextEntry(jarEntry) + + val in = new FileInputStream(file._2) + ByteStreams.copy(in, jarStream) + in.close() + } + jarStream.close() + jarFileStream.close() + + jarFile + } + + /** + * Creates a jar and pom file, mocking a Maven repository. The root path can be supplied with + * `tempDir`, dependencies can be created into the same repo, and python files can also be packed + * inside the jar. + * + * @param artifact The maven coordinate to generate the jar and pom for. + * @param dependencies List of dependencies this artifact might have to also create jars and poms. + * @param tempDir The root folder of the repository + * @param useIvyLayout whether to mock the Ivy layout for local repository testing + * @param withPython Whether to pack python files inside the jar for extensive testing. + * @return Root path of the repository + */ + private def createLocalRepository( + artifact: MavenCoordinate, + dependencies: Option[Seq[MavenCoordinate]] = None, + tempDir: Option[Path] = None, + useIvyLayout: Boolean = false, + withPython: Boolean = false): Path = { + // Where the root of the repository exists, and what Ivy will search in + val tempPath = tempDir.getOrElse(Files.createTempDirectory(null)) + // Create directory if it doesn't exist + Files.createDirectories(tempPath) + // Where to create temporary class files and such + val root = Files.createTempDirectory(tempPath, null).toFile + try { + val jarPath = pathFromCoordinate(artifact, tempPath, "jar", useIvyLayout) + Files.createDirectories(jarPath) + val className = "MyLib" + + val javaClass = createJavaClass(root, className, artifact.groupId) + // A tuple of files representation in the jar, and the file + val javaFile = (artifact.groupId.replace(".", "/") + "/" + javaClass.getName, javaClass) + val allFiles = + if (withPython) { + val pythonFile = createPythonFile(root) + Seq(javaFile, (pythonFile.getName, pythonFile)) + } else { + Seq(javaFile) + } + val jarFile = packJar(jarPath.toFile, artifact, allFiles) + assert(jarFile.exists(), "Problem creating Jar file") + val pomPath = pathFromCoordinate(artifact, tempPath, "pom", useIvyLayout) + Files.createDirectories(pomPath) + val pomFile = createPom(pomPath.toFile, artifact, dependencies) + assert(pomFile.exists(), "Problem creating Pom file") + } finally { + FileUtils.deleteDirectory(root) + } + tempPath + } + + /** + * Creates a suite of jars and poms, with or without dependencies, mocking a maven repository. + * @param artifact The main maven coordinate to generate the jar and pom for. + * @param dependencies List of dependencies this artifact might have to also create jars and poms. + * @param rootDir The root folder of the repository (like `~/.m2/repositories`) + * @param useIvyLayout whether to mock the Ivy layout for local repository testing + * @param withPython Whether to pack python files inside the jar for extensive testing. + * @return Root path of the repository. Will be `rootDir` if supplied. + */ + private[deploy] def createLocalRepositoryForTests( + artifact: MavenCoordinate, + dependencies: Option[String], + rootDir: Option[Path], + useIvyLayout: Boolean = false, + withPython: Boolean = false): Path = { + val deps = dependencies.map(SparkSubmitUtils.extractMavenCoordinates) + val mainRepo = createLocalRepository(artifact, deps, rootDir, useIvyLayout, withPython) + deps.foreach { seq => seq.foreach { dep => + createLocalRepository(dep, None, Some(mainRepo), useIvyLayout, withPython = false) + }} + mainRepo + } + + /** + * Creates a repository for a test, and cleans it up afterwards. + * + * @param artifact The main maven coordinate to generate the jar and pom for. + * @param dependencies List of dependencies this artifact might have to also create jars and poms. + * @param rootDir The root folder of the repository (like `~/.m2/repositories`) + * @param useIvyLayout whether to mock the Ivy layout for local repository testing + * @param withPython Whether to pack python files inside the jar for extensive testing. + * @return Root path of the repository. Will be `rootDir` if supplied. + */ + private[deploy] def withRepository( + artifact: MavenCoordinate, + dependencies: Option[String], + rootDir: Option[Path], + useIvyLayout: Boolean = false, + withPython: Boolean = false)(f: String => Unit): Unit = { + val repo = createLocalRepositoryForTests(artifact, dependencies, rootDir, useIvyLayout, + withPython) + try { + f(repo.toUri.toString) + } finally { + // Clean up + if (repo.toString.contains(".m2") || repo.toString.contains(".ivy2")) { + FileUtils.deleteDirectory(new File(repo.toFile, + artifact.groupId.replace(".", File.separator) + File.separator + artifact.artifactId)) + dependencies.map(SparkSubmitUtils.extractMavenCoordinates).foreach { seq => + seq.foreach { dep => + FileUtils.deleteDirectory(new File(repo.toFile, + dep.artifactId.replace(".", File.separator))) + } + } + } else { + FileUtils.deleteDirectory(repo.toFile) + } + } + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 4561e5b8e9663..8360b94599547 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.deploy.SparkSubmit._ +import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.util.{ResetSystemProperties, Utils} // Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch @@ -231,7 +232,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties val childArgsStr = childArgs.mkString(" ") if (useRest) { childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") - mainClass should be ("org.apache.spark.deploy.rest.StandaloneRestClient") + mainClass should be ("org.apache.spark.deploy.rest.RestSubmissionClient") } else { childArgsStr should startWith ("--supervise --memory 4g --cores 5") childArgsStr should include regex "launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2" @@ -334,18 +335,22 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties runSparkSubmit(args) } - test("includes jars passed in through --packages") { + ignore("includes jars passed in through --packages") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) - val packagesString = "com.databricks:spark-csv_2.10:0.1,com.databricks:spark-avro_2.10:0.1" - val args = Seq( - "--class", JarCreationTest.getClass.getName.stripSuffix("$"), - "--name", "testApp", - "--master", "local-cluster[2,1,512]", - "--packages", packagesString, - "--conf", "spark.ui.enabled=false", - unusedJar.toString, - "com.databricks.spark.csv.DefaultSource", "com.databricks.spark.avro.DefaultSource") - runSparkSubmit(args) + val main = MavenCoordinate("my.great.lib", "mylib", "0.1") + val dep = MavenCoordinate("my.great.dep", "mylib", "0.1") + IvyTestUtils.withRepository(main, Some(dep.toString), None) { repo => + val args = Seq( + "--class", JarCreationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local-cluster[2,1,512]", + "--packages", Seq(main, dep).mkString(","), + "--repositories", repo, + "--conf", "spark.ui.enabled=false", + unusedJar.toString, + "my.great.lib.MyLib", "my.great.dep.MyLib") + runSparkSubmit(args) + } } test("resolves command line argument paths correctly") { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 8bcca926097a1..cc79ee7ea20b4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -20,12 +20,14 @@ package org.apache.spark.deploy import java.io.{PrintStream, OutputStream, File} import scala.collection.mutable.ArrayBuffer - import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.ivy.core.module.descriptor.MDArtifact +import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.resolver.IBiblioResolver +import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate + class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { private val noOpOutputStream = new OutputStream { @@ -56,24 +58,23 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { } test("create repo resolvers") { - val resolver1 = SparkSubmitUtils.createRepoResolvers(None) + val settings = new IvySettings + val res1 = SparkSubmitUtils.createRepoResolvers(None, settings) // should have central and spark-packages by default - assert(resolver1.getResolvers.size() === 2) - assert(resolver1.getResolvers.get(0).asInstanceOf[IBiblioResolver].getName === "central") - assert(resolver1.getResolvers.get(1).asInstanceOf[IBiblioResolver].getName === "spark-packages") + assert(res1.getResolvers.size() === 4) + assert(res1.getResolvers.get(0).asInstanceOf[IBiblioResolver].getName === "local-m2-cache") + assert(res1.getResolvers.get(1).asInstanceOf[IBiblioResolver].getName === "local-ivy-cache") + assert(res1.getResolvers.get(2).asInstanceOf[IBiblioResolver].getName === "central") + assert(res1.getResolvers.get(3).asInstanceOf[IBiblioResolver].getName === "spark-packages") val repos = "a/1,b/2,c/3" - val resolver2 = SparkSubmitUtils.createRepoResolvers(Option(repos)) - assert(resolver2.getResolvers.size() === 5) + val resolver2 = SparkSubmitUtils.createRepoResolvers(Option(repos), settings) + assert(resolver2.getResolvers.size() === 7) val expected = repos.split(",").map(r => s"$r/") resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: IBiblioResolver, i) => - if (i == 0) { - assert(resolver.getName === "central") - } else if (i == 1) { - assert(resolver.getName === "spark-packages") - } else { - assert(resolver.getName === s"repo-${i - 1}") - assert(resolver.getRoot === expected(i - 2)) + if (i > 3) { + assert(resolver.getName === s"repo-${i - 3}") + assert(resolver.getRoot === expected(i - 4)) } } } @@ -88,7 +89,7 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { } test("ivy path works correctly") { - val ivyPath = "dummy/ivy" + val ivyPath = "dummy" + File.separator + "ivy" val md = SparkSubmitUtils.getModuleDescriptor val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(ivyPath)) @@ -97,17 +98,38 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { assert(index >= 0) jPaths = jPaths.substring(index + ivyPath.length) } - // end to end - val jarPath = SparkSubmitUtils.resolveMavenCoordinates( - "com.databricks:spark-csv_2.10:0.1", None, Option(ivyPath), true) - assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") + val main = MavenCoordinate("my.awesome.lib", "mylib", "0.1") + IvyTestUtils.withRepository(main, None, None) { repo => + // end to end + val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, Option(repo), + Option(ivyPath), true) + assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") + } } - test("search for artifact at other repositories") { - val path = SparkSubmitUtils.resolveMavenCoordinates("com.agimatec:agimatec-validation:0.9.3", - Option("https://oss.sonatype.org/content/repositories/agimatec/"), None, true) - assert(path.indexOf("agimatec-validation") >= 0, "should find package. If it doesn't, check" + - "if package still exists. If it has been removed, replace the example in this test.") + test("search for artifact at local repositories") { + val main = new MavenCoordinate("my.awesome.lib", "mylib", "0.1") + // Local M2 repository + IvyTestUtils.withRepository(main, None, Some(SparkSubmitUtils.m2Path)) { repo => + val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, None, true) + assert(jarPath.indexOf("mylib") >= 0, "should find artifact") + } + // Local Ivy Repository + val settings = new IvySettings + val ivyLocal = new File(settings.getDefaultIvyUserDir, "local" + File.separator) + IvyTestUtils.withRepository(main, None, Some(ivyLocal.toPath), true) { repo => + val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, None, true) + assert(jarPath.indexOf("mylib") >= 0, "should find artifact") + } + // Local ivy repository with modified home + val dummyIvyPath = "dummy" + File.separator + "ivy" + val dummyIvyLocal = new File(dummyIvyPath, "local" + File.separator) + IvyTestUtils.withRepository(main, None, Some(dummyIvyLocal.toPath), true) { repo => + val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, + Some(dummyIvyPath), true) + assert(jarPath.indexOf("mylib") >= 0, "should find artifact") + assert(jarPath.indexOf(dummyIvyPath) >= 0, "should be in new ivy path") + } } test("dependency not found throws RuntimeException") { @@ -126,11 +148,11 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { val path = SparkSubmitUtils.resolveMavenCoordinates(coordinates, None, None, true) assert(path === "", "should return empty path") - // Should not exclude the following dependency. Will throw an error, because it doesn't exist, - // but the fact that it is checking means that it wasn't excluded. - intercept[RuntimeException] { - SparkSubmitUtils.resolveMavenCoordinates(coordinates + - ",org.apache.spark:spark-streaming-kafka-assembly_2.10:1.2.0", None, None, true) + val main = MavenCoordinate("org.apache.spark", "spark-streaming-kafka-assembly_2.10", "1.2.0") + IvyTestUtils.withRepository(main, None, None) { repo => + val files = SparkSubmitUtils.resolveMavenCoordinates(coordinates + "," + main.toString, + Some(repo), None, true) + assert(files.indexOf(main.artifactId) >= 0, "Did not return artifact") } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index fcae603c7d18e..9e367a0d9af0d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -224,9 +224,9 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers EventLoggingListener.initEventLog(new FileOutputStream(file)) } val writer = new OutputStreamWriter(bstream, "UTF-8") - try { + Utils.tryWithSafeFinally { events.foreach(e => writer.write(compact(render(JsonProtocol.sparkEventToJson(e))) + "\n")) - } finally { + } { writer.close() } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 8e09976636386..0a318a27ac212 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -39,9 +39,9 @@ import org.apache.spark.deploy.master.DriverState._ * Tests for the REST application submission protocol used in standalone cluster mode. */ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { - private val client = new StandaloneRestClient + private val client = new RestSubmissionClient private var actorSystem: Option[ActorSystem] = None - private var server: Option[StandaloneRestServer] = None + private var server: Option[RestSubmissionServer] = None override def afterEach() { actorSystem.foreach(_.shutdown()) @@ -89,7 +89,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { conf.set("spark.app.name", "dreamer") val appArgs = Array("one", "two", "six") // main method calls this - val response = StandaloneRestClient.run("app-resource", "main-class", appArgs, conf) + val response = RestSubmissionClient.run("app-resource", "main-class", appArgs, conf) val submitResponse = getSubmitResponse(response) assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) assert(submitResponse.serverSparkVersion === SPARK_VERSION) @@ -208,7 +208,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("good request paths") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val json = constructSubmitRequest(masterUrl).toJson val submitRequestPath = s"$httpUrl/$v/submissions/create" val killRequestPath = s"$httpUrl/$v/submissions/kill" @@ -238,7 +238,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("good request paths, bad requests") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" val killRequestPath = s"$httpUrl/$v/submissions/kill" val statusRequestPath = s"$httpUrl/$v/submissions/status" @@ -276,7 +276,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("bad request paths") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val (response1, code1) = sendHttpRequestWithResponse(httpUrl, "GET") val (response2, code2) = sendHttpRequestWithResponse(s"$httpUrl/", "GET") val (response3, code3) = sendHttpRequestWithResponse(s"$httpUrl/$v", "GET") @@ -292,7 +292,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { assert(code5 === HttpServletResponse.SC_BAD_REQUEST) assert(code6 === HttpServletResponse.SC_BAD_REQUEST) assert(code7 === HttpServletResponse.SC_BAD_REQUEST) - assert(code8 === StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) + assert(code8 === RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) // all responses should be error responses val errorResponse1 = getErrorResponse(response1) val errorResponse2 = getErrorResponse(response2) @@ -310,13 +310,13 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { assert(errorResponse5.highestProtocolVersion === null) assert(errorResponse6.highestProtocolVersion === null) assert(errorResponse7.highestProtocolVersion === null) - assert(errorResponse8.highestProtocolVersion === StandaloneRestServer.PROTOCOL_VERSION) + assert(errorResponse8.highestProtocolVersion === RestSubmissionServer.PROTOCOL_VERSION) } test("server returns unknown fields") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" val oldJson = constructSubmitRequest(masterUrl).toJson val oldFields = parse(oldJson).asInstanceOf[JObject].obj @@ -340,7 +340,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("client handles faulty server") { val masterUrl = startFaultyServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" val killRequestPath = s"$httpUrl/$v/submissions/kill/anything" val statusRequestPath = s"$httpUrl/$v/submissions/status/anything" @@ -400,9 +400,9 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster)) val _server = if (faulty) { - new FaultyStandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + new FaultyStandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077") } else { - new StandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077") } val port = _server.start() // set these to clean them up after every test @@ -563,20 +563,18 @@ private class SmarterMaster extends Actor { private class FaultyStandaloneRestServer( host: String, requestedPort: Int, + masterConf: SparkConf, masterActor: ActorRef, - masterUrl: String, - masterConf: SparkConf) - extends StandaloneRestServer(host, requestedPort, masterActor, masterUrl, masterConf) { + masterUrl: String) + extends RestSubmissionServer(host, requestedPort, masterConf) { - protected override val contextToServlet = Map[String, StandaloneRestServlet]( - s"$baseContext/create/*" -> new MalformedSubmitServlet, - s"$baseContext/kill/*" -> new InvalidKillServlet, - s"$baseContext/status/*" -> new ExplodingStatusServlet, - "/*" -> new ErrorServlet - ) + protected override val submitRequestServlet = new MalformedSubmitServlet + protected override val killRequestServlet = new InvalidKillServlet + protected override val statusRequestServlet = new ExplodingStatusServlet /** A faulty servlet that produces malformed responses. */ - class MalformedSubmitServlet extends SubmitRequestServlet(masterActor, masterUrl, masterConf) { + class MalformedSubmitServlet + extends StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) { protected override def sendResponse( responseMessage: SubmitRestProtocolResponse, responseServlet: HttpServletResponse): Unit = { @@ -586,7 +584,7 @@ private class FaultyStandaloneRestServer( } /** A faulty servlet that produces invalid responses. */ - class InvalidKillServlet extends KillRequestServlet(masterActor, masterConf) { + class InvalidKillServlet extends StandaloneKillRequestServlet(masterActor, masterConf) { protected override def handleKill(submissionId: String): KillSubmissionResponse = { val k = super.handleKill(submissionId) k.submissionId = null @@ -595,7 +593,7 @@ private class FaultyStandaloneRestServer( } /** A faulty status servlet that explodes. */ - class ExplodingStatusServlet extends StatusRequestServlet(masterActor, masterConf) { + class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterActor, masterConf) { private def explode: Int = 1 / 0 protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { val s = super.handleStatus(submissionId) diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index aea76c1adcc09..85eb2a1d07ba4 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -176,7 +176,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContextImpl(0, 0, 0, 0) + val tContext = new TaskContextImpl(0, 0, 0, 0, null) val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 44c88b00c442a..ae3339d80f9c6 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -100,8 +100,8 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) - val newRpcEndpointRef = rpcEndpointRef.askWithReply[RpcEndpointRef]("Hello") - val reply = newRpcEndpointRef.askWithReply[String]("Echo") + val newRpcEndpointRef = rpcEndpointRef.askWithRetry[RpcEndpointRef]("Hello") + val reply = newRpcEndpointRef.askWithRetry[String]("Echo") assert("Echo" === reply) } @@ -115,7 +115,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } } }) - val reply = rpcEndpointRef.askWithReply[String]("hello") + val reply = rpcEndpointRef.askWithRetry[String]("hello") assert("hello" === reply) } @@ -134,7 +134,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") try { - val reply = rpcEndpointRef.askWithReply[String]("hello") + val reply = rpcEndpointRef.askWithRetry[String]("hello") assert("hello" === reply) } finally { anotherEnv.shutdown() @@ -162,7 +162,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") try { val e = intercept[Exception] { - rpcEndpointRef.askWithReply[String]("hello", 1 millis) + rpcEndpointRef.askWithRetry[String]("hello", 1 millis) } assert(e.isInstanceOf[TimeoutException] || e.getCause.isInstanceOf[TimeoutException]) } finally { @@ -399,7 +399,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } }) - val f = endpointRef.sendWithReply[String]("Hi") + val f = endpointRef.ask[String]("Hi") val ack = Await.result(f, 5 seconds) assert("ack" === ack) @@ -419,7 +419,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") try { - val f = rpcEndpointRef.sendWithReply[String]("hello") + val f = rpcEndpointRef.ask[String]("hello") val ack = Await.result(f, 5 seconds) assert("ack" === ack) } finally { @@ -437,7 +437,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } }) - val f = endpointRef.sendWithReply[String]("Hi") + val f = endpointRef.ask[String]("Hi") val e = intercept[SparkException] { Await.result(f, 5 seconds) } @@ -460,7 +460,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-remotely-error") try { - val f = rpcEndpointRef.sendWithReply[String]("hello") + val f = rpcEndpointRef.ask[String]("hello") val e = intercept[SparkException] { Await.result(f, 5 seconds) } @@ -529,7 +529,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-unserializable-error") try { - val f = rpcEndpointRef.sendWithReply[String]("hello") + val f = rpcEndpointRef.ask[String]("hello") intercept[TimeoutException] { Await.result(f, 1 seconds) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 057e226916027..83ae8701243e5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContextImpl(0, 0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0, null) val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala new file mode 100644 index 0000000000000..f28e29e9b8d8e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.mesos + +import java.util.Date + +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.deploy.Command +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.scheduler.cluster.mesos._ +import org.apache.spark.{LocalSparkContext, SparkConf} + + +class MesosClusterSchedulerSuite extends FunSuite with LocalSparkContext with MockitoSugar { + + private val command = new Command("mainClass", Seq("arg"), null, null, null, null) + + test("can queue drivers") { + val conf = new SparkConf() + conf.setMaster("mesos://localhost:5050") + conf.setAppName("spark mesos") + val scheduler = new MesosClusterScheduler( + new BlackHoleMesosClusterPersistenceEngineFactory, conf) { + override def start(): Unit = { ready = true } + } + scheduler.start() + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", 1000, 1, true, + command, Map[String, String](), "s1", new Date())) + assert(response.success) + val response2 = + scheduler.submitDriver(new MesosDriverDescription( + "d1", "jar", 1000, 1, true, command, Map[String, String](), "s2", new Date())) + assert(response2.success) + val state = scheduler.getSchedulerState() + val queuedDrivers = state.queuedDrivers.toList + assert(queuedDrivers(0).submissionId == response.submissionId) + assert(queuedDrivers(1).submissionId == response2.submissionId) + } + + test("can kill queued drivers") { + val conf = new SparkConf() + conf.setMaster("mesos://localhost:5050") + conf.setAppName("spark mesos") + val scheduler = new MesosClusterScheduler( + new BlackHoleMesosClusterPersistenceEngineFactory, conf) { + override def start(): Unit = { ready = true } + } + scheduler.start() + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", 1000, 1, true, + command, Map[String, String](), "s1", new Date())) + assert(response.success) + val killResponse = scheduler.killDriver(response.submissionId) + assert(killResponse.success) + val state = scheduler.getSchedulerState() + assert(state.queuedDrivers.isEmpty) + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala index 967c9e9899c9d..da98d09184735 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -33,8 +33,8 @@ class KryoSerializerResizableOutputSuite extends FunSuite { test("kryo without resizable output buffer should fail on large array") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set("spark.kryoserializer.buffer.max", "1m") val sc = new SparkContext("local", "test", conf) intercept[SparkException](sc.parallelize(x).collect()) LocalSparkContext.stop(sc) @@ -43,8 +43,8 @@ class KryoSerializerResizableOutputSuite extends FunSuite { test("kryo with resizable output buffer should succeed on large array") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "2") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set("spark.kryoserializer.buffer.max", "2m") val sc = new SparkContext("local", "test", conf) assert(sc.parallelize(x).collect() === x) LocalSparkContext.stop(sc) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index b070a54aa989b..1b13559e77cb8 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -269,7 +269,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("serialization buffer overflow reporting") { import org.apache.spark.SparkException - val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max.mb" + val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max" val largeObject = (1 to 1000000).toArray diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index ffa5162a31841..f647200402ecb 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -50,7 +50,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val allStores = new ArrayBuffer[BlockManager] // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 7d82a7c66ad1a..f5b410f41da27 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -55,7 +55,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val shuffleManager = new HashShuffleManager(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. @@ -356,7 +356,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - val reregister = !master.driverEndpoint.askWithReply[Boolean]( + val reregister = !master.driverEndpoint.askWithRetry[Boolean]( BlockManagerHeartbeat(store.blockManagerId)) assert(reregister == true) } @@ -814,14 +814,14 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach // be nice to refactor classes involved in disk storage in a way that // allows for easier testing. val blockManager = mock(classOf[BlockManager]) - when(blockManager.conf).thenReturn(conf.clone.set(confKey, 0.toString)) + when(blockManager.conf).thenReturn(conf.clone.set(confKey, "0")) val diskBlockManager = new DiskBlockManager(blockManager, conf) val diskStoreMapped = new DiskStore(blockManager, diskBlockManager) diskStoreMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) val mapped = diskStoreMapped.getBytes(blockId).get - when(blockManager.conf).thenReturn(conf.clone.set(confKey, (1000 * 1000).toString)) + when(blockManager.conf).thenReturn(conf.clone.set(confKey, "1m")) val diskStoreNotMapped = new DiskStore(blockManager, diskBlockManager) diskStoreNotMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) val notMapped = diskStoreNotMapped.getBytes(blockId).get diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 37b593b2c5f79..2080c432d77db 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -89,7 +89,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContextImpl(0, 0, 0, 0), + new TaskContextImpl(0, 0, 0, 0, null), transfer, blockManager, blocksByAddress, @@ -154,7 +154,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0, null) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, @@ -217,7 +217,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0, null) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 1ba99803f5a0e..62a3cbcdf69ea 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -23,7 +23,6 @@ import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols import java.util.concurrent.TimeUnit import java.util.Locale -import java.util.PriorityQueue import scala.collection.mutable.ListBuffer import scala.util.Random @@ -35,6 +34,7 @@ import org.scalatest.FunSuite import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.network.util.ByteUnit import org.apache.spark.SparkConf class UtilsSuite extends FunSuite with ResetSystemProperties { @@ -65,6 +65,10 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(Utils.timeStringAsMs("1d") === TimeUnit.DAYS.toMillis(1)) // Test invalid strings + intercept[NumberFormatException] { + Utils.timeStringAsMs("600l") + } + intercept[NumberFormatException] { Utils.timeStringAsMs("This breaks 600s") } @@ -82,6 +86,100 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { } } + test("Test byteString conversion") { + // Test zero + assert(Utils.byteStringAsBytes("0") === 0) + + assert(Utils.byteStringAsGb("1") === 1) + assert(Utils.byteStringAsGb("1g") === 1) + assert(Utils.byteStringAsGb("1023m") === 0) + assert(Utils.byteStringAsGb("1024m") === 1) + assert(Utils.byteStringAsGb("1048575k") === 0) + assert(Utils.byteStringAsGb("1048576k") === 1) + assert(Utils.byteStringAsGb("1k") === 0) + assert(Utils.byteStringAsGb("1t") === ByteUnit.TiB.toGiB(1)) + assert(Utils.byteStringAsGb("1p") === ByteUnit.PiB.toGiB(1)) + + assert(Utils.byteStringAsMb("1") === 1) + assert(Utils.byteStringAsMb("1m") === 1) + assert(Utils.byteStringAsMb("1048575b") === 0) + assert(Utils.byteStringAsMb("1048576b") === 1) + assert(Utils.byteStringAsMb("1023k") === 0) + assert(Utils.byteStringAsMb("1024k") === 1) + assert(Utils.byteStringAsMb("3645k") === 3) + assert(Utils.byteStringAsMb("1024gb") === 1048576) + assert(Utils.byteStringAsMb("1g") === ByteUnit.GiB.toMiB(1)) + assert(Utils.byteStringAsMb("1t") === ByteUnit.TiB.toMiB(1)) + assert(Utils.byteStringAsMb("1p") === ByteUnit.PiB.toMiB(1)) + + assert(Utils.byteStringAsKb("1") === 1) + assert(Utils.byteStringAsKb("1k") === 1) + assert(Utils.byteStringAsKb("1m") === ByteUnit.MiB.toKiB(1)) + assert(Utils.byteStringAsKb("1g") === ByteUnit.GiB.toKiB(1)) + assert(Utils.byteStringAsKb("1t") === ByteUnit.TiB.toKiB(1)) + assert(Utils.byteStringAsKb("1p") === ByteUnit.PiB.toKiB(1)) + + assert(Utils.byteStringAsBytes("1") === 1) + assert(Utils.byteStringAsBytes("1k") === ByteUnit.KiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1m") === ByteUnit.MiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1g") === ByteUnit.GiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1t") === ByteUnit.TiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1p") === ByteUnit.PiB.toBytes(1)) + + // Overflow handling, 1073741824p exceeds Long.MAX_VALUE if converted straight to Bytes + // This demonstrates that we can have e.g 1024^3 PB without overflowing. + assert(Utils.byteStringAsGb("1073741824p") === ByteUnit.PiB.toGiB(1073741824)) + assert(Utils.byteStringAsMb("1073741824p") === ByteUnit.PiB.toMiB(1073741824)) + + // Run this to confirm it doesn't throw an exception + assert(Utils.byteStringAsBytes("9223372036854775807") === 9223372036854775807L) + assert(ByteUnit.PiB.toPiB(9223372036854775807L) === 9223372036854775807L) + + // Test overflow exception + intercept[IllegalArgumentException] { + // This value exceeds Long.MAX when converted to bytes + Utils.byteStringAsBytes("9223372036854775808") + } + + // Test overflow exception + intercept[IllegalArgumentException] { + // This value exceeds Long.MAX when converted to TB + ByteUnit.PiB.toTiB(9223372036854775807L) + } + + // Test fractional string + intercept[NumberFormatException] { + Utils.byteStringAsMb("0.064") + } + + // Test fractional string + intercept[NumberFormatException] { + Utils.byteStringAsMb("0.064m") + } + + // Test invalid strings + intercept[NumberFormatException] { + Utils.byteStringAsBytes("500ub") + } + + // Test invalid strings + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This breaks 600b") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This breaks 600") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("600gb This breaks") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This 123mb breaks") + } + } + test("bytesToString") { assert(Utils.bytesToString(10) === "10.0 B") assert(Utils.bytesToString(1500) === "1500.0 B") diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 9ff067f86af44..de26aa351b0d2 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -506,7 +506,10 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val ord = implicitly[Ordering[Int]] val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), Some(ord), None) - sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i))) + + // avoid combine before spill + sorter.insertAll((0 until 50000).iterator.map(i => (i , 2 * i))) + sorter.insertAll((0 until 50000).iterator.map(i => (i, 2 * i + 1))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet val expected = (0 until 3).map(p => { (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) diff --git a/docs/configuration.md b/docs/configuration.md index d587b91124cb8..72105feba4919 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -48,6 +48,17 @@ The following format is accepted: 5d (days) 1y (years) + +Properties that specify a byte size should be configured with a unit of size. +The following format is accepted: + + 1b (bytes) + 1k or 1kb (kibibytes = 1024 bytes) + 1m or 1mb (mebibytes = 1024 kibibytes) + 1g or 1gb (gibibytes = 1024 mebibytes) + 1t or 1tb (tebibytes = 1024 gibibytes) + 1p or 1pb (pebibytes = 1024 tebibytes) + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, if you'd like to run the same application with different masters or different @@ -272,12 +283,11 @@ Apart from these, the following properties are also available, and may be useful - spark.executor.logs.rolling.size.maxBytes + spark.executor.logs.rolling.maxSize (none) Set the max size of the file by which the executor logs will be rolled over. - Rolling is disabled by default. Value is set in terms of bytes. - See spark.executor.logs.rolling.maxRetainedFiles + Rolling is disabled by default. See spark.executor.logs.rolling.maxRetainedFiles for automatic cleaning of old logs. @@ -366,10 +376,10 @@ Apart from these, the following properties are also available, and may be useful - - + + @@ -403,10 +413,10 @@ Apart from these, the following properties are also available, and may be useful - - + + @@ -582,18 +592,18 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + @@ -641,19 +651,19 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + @@ -698,9 +708,9 @@ Apart from these, the following properties are also available, and may be useful - + @@ -816,9 +826,9 @@ Apart from these, the following properties are also available, and may be useful - + diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index 963e88a3e1d8f..8d9c2ba2041b2 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -32,7 +32,7 @@ Resource allocation can be configured as follows, based on the cluster type: * **Standalone mode:** By default, applications submitted to the standalone mode cluster will run in FIFO (first-in-first-out) order, and each application will try to use all available nodes. You can limit the number of nodes an application uses by setting the `spark.cores.max` configuration property in it, - or change the default for applications that don't set this setting through `spark.deploy.defaultCores`. + or change the default for applications that don't set this setting through `spark.deploy.defaultCores`. Finally, in addition to controlling cores, each application's `spark.executor.memory` setting controls its memory use. * **Mesos:** To use static partitioning on Mesos, set the `spark.mesos.coarse` configuration property to `true`, diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 594bf78b67713..8f53d8201a089 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -78,6 +78,9 @@ To verify that the Mesos cluster is ready for Spark, navigate to the Mesos maste To use Mesos from Spark, you need a Spark binary package available in a place accessible by Mesos, and a Spark driver program configured to connect to Mesos. +Alternatively, you can also install Spark in the same location in all the Mesos slaves, and configure +`spark.mesos.executor.home` (defaults to SPARK_HOME) to point to that location. + ## Uploading Spark Package When Mesos runs a task on a Mesos slave for the first time, that slave must have a Spark binary @@ -107,7 +110,11 @@ the `make-distribution.sh` script included in a Spark source tarball/checkout. The Master URLs for Mesos are in the form `mesos://host:5050` for a single-master Mesos cluster, or `mesos://zk://host:2181` for a multi-master Mesos cluster using ZooKeeper. -The driver also needs some configuration in `spark-env.sh` to interact properly with Mesos: +## Client Mode + +In client mode, a Spark Mesos framework is launched directly on the client machine and waits for the driver output. + +The driver needs some configuration in `spark-env.sh` to interact properly with Mesos: 1. In `spark-env.sh` set some environment variables: * `export MESOS_NATIVE_JAVA_LIBRARY=`. This path is typically @@ -129,8 +136,7 @@ val sc = new SparkContext(conf) {% endhighlight %} (You can also use [`spark-submit`](submitting-applications.html) and configure `spark.executor.uri` -in the [conf/spark-defaults.conf](configuration.html#loading-default-configurations) file. Note -that `spark-submit` currently only supports deploying the Spark driver in `client` mode for Mesos.) +in the [conf/spark-defaults.conf](configuration.html#loading-default-configurations) file.) When running a shell, the `spark.executor.uri` parameter is inherited from `SPARK_EXECUTOR_URI`, so it does not need to be redundantly passed in as a system property. @@ -139,6 +145,17 @@ it does not need to be redundantly passed in as a system property. ./bin/spark-shell --master mesos://host:5050 {% endhighlight %} +## Cluster mode + +Spark on Mesos also supports cluster mode, where the driver is launched in the cluster and the client +can find the results of the driver from the Mesos Web UI. + +To use cluster mode, you must start the MesosClusterDispatcher in your cluster via the `sbin/start-mesos-dispatcher.sh` script, +passing in the Mesos master url (e.g: mesos://host:5050). + +From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master url +to the url of the MesosClusterDispatcher (e.g: mesos://dispatcher:7077). You can view driver statuses on the +Spark cluster Web UI. # Mesos Run Modes diff --git a/docs/tuning.md b/docs/tuning.md index cbd227868b248..1cb223e74f382 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -60,7 +60,7 @@ val sc = new SparkContext(conf) The [Kryo documentation](https://github.com/EsotericSoftware/kryo) describes more advanced registration options, such as adding custom serialization code. -If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb` +If your objects are large, you may also need to increase the `spark.kryoserializer.buffer` config property. The default is 2, but this value needs to be large enough to hold the *largest* object you will serialize. diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index eaf00d09f550d..46377a99c4857 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -28,7 +28,6 @@ import org.apache.spark.ml.classification.ClassificationModel; import org.apache.spark.ml.param.IntParam; import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.param.Params; import org.apache.spark.ml.param.Params$; import org.apache.spark.mllib.linalg.BLAS; import org.apache.spark.mllib.linalg.Vector; @@ -100,11 +99,12 @@ public static void main(String[] args) throws Exception { /** * Example of defining a type of {@link Classifier}. * - * NOTE: This is private since it is an example. In practice, you may not want it to be private. + * Note: Some IDEs (e.g., IntelliJ) will complain that this will not compile due to + * {@link org.apache.spark.ml.param.Params#set} using incompatible return types. + * However, this should still compile and run successfully. */ class MyJavaLogisticRegression - extends Classifier - implements Params { + extends Classifier { /** * Param for max number of iterations @@ -145,10 +145,12 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) /** * Example of defining a type of {@link ClassificationModel}. * - * NOTE: This is private since it is an example. In practice, you may not want it to be private. + * Note: Some IDEs (e.g., IntelliJ) will complain that this will not compile due to + * {@link org.apache.spark.ml.param.Params#set} using incompatible return types. + * However, this should still compile and run successfully. */ class MyJavaLogisticRegressionModel - extends ClassificationModel implements Params { + extends ClassificationModel { private MyJavaLogisticRegression parent_; public MyJavaLogisticRegression parent() { return parent_; } diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py new file mode 100644 index 0000000000000..6ef188a220c51 --- /dev/null +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text directly received from Kafka in every 2 seconds. + Usage: direct_kafka_wordcount.py + + To run this on your local machine, you need to setup Kafka and create a producer first, see + http://kafka.apache.org/documentation.html#quickstart + + and then run the example + `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/\ + spark-streaming-kafka-assembly-*.jar \ + examples/src/main/python/streaming/direct_kafka_wordcount.py \ + localhost:9092 test` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.kafka import KafkaUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: direct_kafka_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount") + ssc = StreamingContext(sc, 2) + + brokers, topic = sys.argv[1:] + kvs = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 0bc36ea65e1ab..99588b0984ab2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -100,7 +100,7 @@ object MovieLensALS { val conf = new SparkConf().setAppName(s"MovieLensALS with $params") if (params.kryo) { conf.registerKryoClasses(Array(classOf[mutable.BitSet], classOf[Rating])) - .set("spark.kryoserializer.buffer.mb", "8") + .set("spark.kryoserializer.buffer", "8m") } val sc = new SparkContext(conf) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 5a9bd4214cf51..d7cf500577c2a 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -21,6 +21,7 @@ import java.lang.{Integer => JInt} import java.lang.{Long => JLong} import java.util.{Map => JMap} import java.util.{Set => JSet} +import java.util.{List => JList} import scala.reflect.ClassTag import scala.collection.JavaConversions._ @@ -30,6 +31,7 @@ import kafka.message.MessageAndMetadata import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder} import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD @@ -79,7 +81,7 @@ object KafkaUtils { topics: Map[String, Int], storageLevel: StorageLevel ): ReceiverInputDStream[(K, V)] = { - val walEnabled = ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false) + val walEnabled = WriteAheadLogUtils.enableReceiverLog(ssc.conf) new KafkaInputDStream[K, V, U, T](ssc, kafkaParams, topics, walEnabled, storageLevel) } @@ -234,7 +236,6 @@ object KafkaUtils { new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, messageHandler) } - /** * Create a RDD from Kafka using offset ranges for each topic and partition. * @@ -558,4 +559,94 @@ private class KafkaUtilsPythonHelper { topics, storageLevel) } + + def createRDD( + jsc: JavaSparkContext, + kafkaParams: JMap[String, String], + offsetRanges: JList[OffsetRange], + leaders: JMap[TopicAndPartition, Broker]): JavaPairRDD[Array[Byte], Array[Byte]] = { + val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], + (Array[Byte], Array[Byte])] { + def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = + (t1.key(), t1.message()) + } + + val jrdd = KafkaUtils.createRDD[ + Array[Byte], + Array[Byte], + DefaultDecoder, + DefaultDecoder, + (Array[Byte], Array[Byte])]( + jsc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + classOf[(Array[Byte], Array[Byte])], + kafkaParams, + offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())), + leaders, + messageHandler + ) + new JavaPairRDD(jrdd.rdd) + } + + def createDirectStream( + jssc: JavaStreamingContext, + kafkaParams: JMap[String, String], + topics: JSet[String], + fromOffsets: JMap[TopicAndPartition, JLong] + ): JavaPairInputDStream[Array[Byte], Array[Byte]] = { + + if (!fromOffsets.isEmpty) { + import scala.collection.JavaConversions._ + val topicsFromOffsets = fromOffsets.keySet().map(_.topic) + if (topicsFromOffsets != topics.toSet) { + throw new IllegalStateException(s"The specified topics: ${topics.toSet.mkString(" ")} " + + s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}") + } + } + + if (fromOffsets.isEmpty) { + KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + kafkaParams, + topics) + } else { + val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], + (Array[Byte], Array[Byte])] { + def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = + (t1.key(), t1.message()) + } + + val jstream = KafkaUtils.createDirectStream[ + Array[Byte], + Array[Byte], + DefaultDecoder, + DefaultDecoder, + (Array[Byte], Array[Byte])]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + classOf[(Array[Byte], Array[Byte])], + kafkaParams, + fromOffsets, + messageHandler) + new JavaPairInputDStream(jstream.inputDStream) + } + } + + def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong + ): OffsetRange = OffsetRange.create(topic, partition, fromOffset, untilOffset) + + def createTopicAndPartition(topic: String, partition: JInt): TopicAndPartition = + TopicAndPartition(topic, partition) + + def createBroker(host: String, port: JInt): Broker = Broker(host, port) } diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 8028e42ffb483..261402856ac5e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -244,7 +244,7 @@ static String quoteForBatchScript(String arg) { boolean needsQuotes = false; for (int i = 0; i < arg.length(); i++) { int c = arg.codePointAt(i); - if (Character.isWhitespace(c) || c == '"' || c == '=') { + if (Character.isWhitespace(c) || c == '"' || c == '=' || c == ',' || c == ';') { needsQuotes = true; break; } @@ -261,15 +261,14 @@ static String quoteForBatchScript(String arg) { quoted.append('"'); break; - case '=': - quoted.append('^'); - break; - default: break; } quoted.appendCodePoint(cp); } + if (arg.codePointAt(arg.length() - 1) == '\\') { + quoted.append("\\"); + } quoted.append("\""); return quoted.toString(); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 206acfb514d86..929b29a49ed70 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -101,12 +101,9 @@ public static void main(String[] argsArray) throws Exception { * The method quotes all arguments so that spaces are handled as expected. Quotes within arguments * are "double quoted" (which is batch for escaping a quote). This page has more details about * quoting and other batch script fun stuff: http://ss64.com/nt/syntax-esc.html - * - * The command is executed using "cmd /c" and formatted in single line, since that's the - * easiest way to consume this from a batch script (see spark-class2.cmd). */ private static String prepareWindowsCommand(List cmd, Map childEnv) { - StringBuilder cmdline = new StringBuilder("cmd /c \""); + StringBuilder cmdline = new StringBuilder(); for (Map.Entry e : childEnv.entrySet()) { cmdline.append(String.format("set %s=%s", e.getKey(), e.getValue())); cmdline.append(" && "); @@ -115,7 +112,6 @@ private static String prepareWindowsCommand(List cmd, Map buildCommand(Map env) throws IOException { } else if (className.equals("org.apache.spark.executor.MesosExecutorBackend")) { javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); memKey = "SPARK_EXECUTOR_MEMORY"; + } else if (className.equals("org.apache.spark.deploy.ExternalShuffleService")) { + javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); + javaOptsKeys.add("SPARK_SHUFFLE_OPTS"); + memKey = "SPARK_DAEMON_MEMORY"; } else if (className.startsWith("org.apache.spark.tools.")) { String sparkHome = getSparkHome(); File toolsDir = new File(join(File.separator, sparkHome, "tools", "target", diff --git a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java index 1ae42eed8a3af..bc513ec9b3d10 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java @@ -74,7 +74,10 @@ public void testWindowsBatchQuoting() { assertEquals("\"a b c\"", quoteForBatchScript("a b c")); assertEquals("\"a \"\"b\"\" c\"", quoteForBatchScript("a \"b\" c")); assertEquals("\"a\"\"b\"\"c\"", quoteForBatchScript("a\"b\"c")); - assertEquals("\"ab^=\"\"cd\"\"\"", quoteForBatchScript("ab=\"cd\"")); + assertEquals("\"ab=\"\"cd\"\"\"", quoteForBatchScript("ab=\"cd\"")); + assertEquals("\"a,b,c\"", quoteForBatchScript("a,b,c")); + assertEquals("\"a;b;c\"", quoteForBatchScript("a;b;c")); + assertEquals("\"a,b,c\\\\\"", quoteForBatchScript("a,b,c\\")); } @Test diff --git a/mllib/pom.xml b/mllib/pom.xml index 5dfab36c76907..a3c57ae26000b 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -109,6 +109,21 @@ test-jar test + + org.jpmml + pmml-model + 1.1.15 + + + com.sun.xml.fastinfoset + FastInfoset + + + com.sun.istack + istack-commons-runtime + + + diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala index d2ca2e6871e6b..8b4b5fd8af986 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 8eddf79cdfe28..6bfeecd764d75 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.{Params, Param, ParamMap} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -86,6 +86,14 @@ class Pipeline extends Estimator[PipelineModel] { def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } def getStages: Array[PipelineStage] = getOrDefault(stages) + override def validate(paramMap: ParamMap): Unit = { + val map = extractParamMap(paramMap) + getStages.foreach { + case pStage: Params => pStage.validate(map) + case _ => + } + } + /** * Fits the pipeline to the input dataset with additional parameters. If a stage is an * [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model. @@ -140,7 +148,7 @@ class Pipeline extends Estimator[PipelineModel] { override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { val map = extractParamMap(paramMap) val theStages = map(stages) - require(theStages.toSet.size == theStages.size, + require(theStages.toSet.size == theStages.length, "Cannot have duplicate components in a pipeline.") theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap)) } @@ -157,6 +165,11 @@ class PipelineModel private[ml] ( private[ml] val stages: Array[Transformer]) extends Model[PipelineModel] with Logging { + override def validate(paramMap: ParamMap): Unit = { + val map = fittingParamMap ++ extractParamMap(paramMap) + stages.foreach(_.validate(map)) + } + /** * Gets the model produced by the input estimator. Throws an NoSuchElementException is the input * estimator does not exist in the pipeline. @@ -168,7 +181,7 @@ class PipelineModel private[ml] ( } if (matched.isEmpty) { throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.") - } else if (matched.size > 1) { + } else if (matched.length > 1) { throw new IllegalStateException(s"Cannot have duplicate estimators in the sample pipeline.") } else { matched.head.asInstanceOf[M] diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index d2e052fbbbf22..3d849867d4c47 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -103,21 +103,16 @@ final class GBTClassifier */ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + " tries to minimize (case-insensitive). Supported options:" + - s" ${GBTClassifier.supportedLossTypes.mkString(", ")}") + s" ${GBTClassifier.supportedLossTypes.mkString(", ")}", + (value: String) => GBTClassifier.supportedLossTypes.contains(value.toLowerCase)) setDefault(lossType -> "logistic") /** @group setParam */ - def setLossType(value: String): this.type = { - val lossStr = value.toLowerCase - require(GBTClassifier.supportedLossTypes.contains(lossStr), "GBTClassifier was given bad loss" + - s" type: $value. Supported options: ${GBTClassifier.supportedLossTypes.mkString(", ")}") - set(lossType, lossStr) - this - } + def setLossType(value: String): this.type = set(lossType, value) /** @group getParam */ - def getLossType: String = getOrDefault(lossType) + def getLossType: String = getOrDefault(lossType).toLowerCase /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldLoss = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index b20f2fc49a8f6..0b3128f9ee8cd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap} import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{VectorUDT, Vector} import org.apache.spark.sql.types.DataType @@ -32,10 +32,14 @@ import org.apache.spark.sql.types.DataType class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { /** - * number of features + * Number of features. Should be > 0. + * (default = 2^18^) * @group param */ - val numFeatures = new IntParam(this, "numFeatures", "number of features") + val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)", + ParamValidators.gt(0)) + + setDefault(numFeatures -> (1 << 18)) /** @group getParam */ def getNumFeatures: Int = getOrDefault(numFeatures) @@ -43,8 +47,6 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { /** @group setParam */ def setNumFeatures(value: Int): this.type = set(numFeatures, value) - setDefault(numFeatures -> (1 << 18)) - override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { val hashingTF = new feature.HashingTF(paramMap(numFeatures)) hashingTF.transform diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index decaeb0da6246..bd2b5f6067e2d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{DoubleParam, ParamMap} +import org.apache.spark.ml.param.{ParamValidators, DoubleParam, ParamMap} import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{VectorUDT, Vector} import org.apache.spark.sql.types.DataType @@ -32,10 +32,13 @@ import org.apache.spark.sql.types.DataType class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { /** - * Normalization in L^p^ space, p = 2 by default. + * Normalization in L^p^ space. Must be >= 1. + * (default: p = 2) * @group param */ - val p = new DoubleParam(this, "p", "the p norm value") + val p = new DoubleParam(this, "p", "the p norm value", ParamValidators.gtEq(1)) + + setDefault(p -> 2.0) /** @group getParam */ def getP: Double = getOrDefault(p) @@ -43,8 +46,6 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { /** @group setParam */ def setP(value: Double): this.type = set(p, value) - setDefault(p -> 2.0) - override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { val normalizer = new feature.Normalizer(paramMap(p)) normalizer.transform diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index d855f04799ae7..1b7c939c2dffe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap} import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType @@ -37,10 +37,13 @@ import org.apache.spark.sql.types.DataType class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExpansion] { /** - * The polynomial degree to expand, which should be larger than 1. + * The polynomial degree to expand, which should be >= 1. A value of 1 means no expansion. + * Default: 2 * @group param */ - val degree = new IntParam(this, "degree", "the polynomial degree to expand") + val degree = new IntParam(this, "degree", "the polynomial degree to expand (>= 1)", + ParamValidators.gt(1)) + setDefault(degree -> 2) /** @group getParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 447851ec034d6..a0e9ed32e0e4c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -31,17 +31,19 @@ import org.apache.spark.sql.types.{StructField, StructType} * Params for [[StandardScaler]] and [[StandardScalerModel]]. */ private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol { - + /** - * False by default. Centers the data with mean before scaling. + * Centers the data with mean before scaling. * It will build a dense output, so this does not work on sparse input * and will raise an exception. + * Default: false * @group param */ val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") /** - * True by default. Scales the data to unit standard deviation. + * Scales the data to unit standard deviation. + * Default: true * @group param */ val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation") @@ -56,7 +58,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams { setDefault(withMean -> false, withStd -> true) - + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 23956c512c8a6..9db3b29e10d69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -23,10 +23,9 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{NumericType, StringType, StructType} import org.apache.spark.util.collection.OpenHashMap /** @@ -37,7 +36,11 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { val map = extractParamMap(paramMap) - SchemaUtils.checkColumnType(schema, map(inputCol), StringType) + val inputColName = map(inputCol) + val inputDataType = schema(inputColName).dataType + require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], + s"The input column $inputColName must be either string type or numeric type, " + + s"but got $inputDataType.") val inputFields = schema.fields val outputColName = map(outputCol) require(inputFields.forall(_.name != outputColName), @@ -51,6 +54,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** * :: AlphaComponent :: * A label indexer that maps a string column of labels to an ML column of label indices. + * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. */ @@ -67,7 +71,9 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = { val map = extractParamMap(paramMap) - val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue() + val counts = dataset.select(col(map(inputCol)).cast(StringType)) + .map(_.getString(0)) + .countByValue() val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray val model = new StringIndexerModel(this, map, labels) Params.inheritValues(map, this, model) @@ -119,7 +125,8 @@ class StringIndexerModel private[ml] ( val outputColName = map(outputCol) val metadata = NominalAttribute.defaultAttr .withName(outputColName).withValues(labels).toMetadata() - dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata)) + dataset.select(col("*"), + indexer(dataset(map(inputCol)).cast(StringType)).as(outputColName, metadata)) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 376a004858b4c..01752ba482d0c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{ParamMap, IntParam, BooleanParam, Param} +import org.apache.spark.ml.param._ import org.apache.spark.sql.types.{DataType, StringType, ArrayType} /** @@ -43,20 +43,20 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { /** * :: AlphaComponent :: * A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default) - * or using it to split the text (set matching to false). Optional parameters also allow to fold - * the text to lowercase prior to it being tokenized and to filer tokens using a minimal length. + * or using it to split the text (set matching to false). Optional parameters also allow filtering + * tokens using a minimal length. * It returns an array of strings that can be empty. - * The default parameters are regex = "\\p{L}+|[^\\p{L}\\s]+", matching = true, - * lowercase = false, minTokenLength = 1 */ @AlphaComponent class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] { /** - * param for minimum token length, default is one to avoid returning empty strings + * Minimum token length, >= 0. + * Default: 1, to avoid returning empty strings * @group param */ - val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length") + val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length (>= 0)", + ParamValidators.gtEq(0)) /** @group setParam */ def setMinTokenLength(value: Int): this.type = set(minTokenLength, value) @@ -65,7 +65,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize def getMinTokenLength: Int = getOrDefault(minTokenLength) /** - * param sets regex as splitting on gaps (true) or matching tokens (false) + * Indicates whether regex splits on gaps (true) or matching tokens (false). + * Default: false * @group param */ val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens") @@ -77,7 +78,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize def getGaps: Boolean = getOrDefault(gaps) /** - * param sets regex pattern used by tokenizer + * Regex pattern used by tokenizer. + * Default: `"\\p{L}+|[^\\p{L}\\s]+"` * @group param */ val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 452faa06e2021..ed833c63c7ef1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute, Attribute, AttributeGroup} -import org.apache.spark.ml.param.{IntParam, ParamMap, Params} +import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap, Params} import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT} import org.apache.spark.sql.{Row, DataFrame} @@ -37,17 +37,19 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu /** * Threshold for the number of values a categorical feature can take. * If a feature is found to have > maxCategories values, then it is declared continuous. + * Must be >= 2. * * (default = 20) */ val maxCategories = new IntParam(this, "maxCategories", - "Threshold for the number of values a categorical feature can take." + - " If a feature is found to have > maxCategories values, then it is declared continuous.") + "Threshold for the number of values a categorical feature can take (>= 2)." + + " If a feature is found to have > maxCategories values, then it is declared continuous.", + ParamValidators.gtEq(2)) + + setDefault(maxCategories -> 20) /** @group getParam */ def getMaxCategories: Int = getOrDefault(maxCategories) - - setDefault(maxCategories -> 20) } /** @@ -90,11 +92,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerParams { /** @group setParam */ - def setMaxCategories(value: Int): this.type = { - require(value > 1, - s"DatasetIndexer given maxCategories = value, but requires maxCategories > 1.") - set(maxCategories, value) - } + def setMaxCategories(value: Int): this.type = set(maxCategories, value) /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -233,6 +231,7 @@ private object VectorIndexer { * - Continuous features (columns) are left unchanged. * This also appends metadata to the output column, marking features as Numeric (continuous), * Nominal (categorical), or Binary (either continuous or categorical). + * Non-ML metadata is not carried over from the input to the output column. * * This maintains vector sparsity. * @@ -283,34 +282,40 @@ class VectorIndexerModel private[ml] ( // TODO: Check more carefully about whether this whole class will be included in a closure. + /** Per-vector transform function */ private val transformFunc: Vector => Vector = { - val sortedCategoricalFeatureIndices = categoryMaps.keys.toArray.sorted + val sortedCatFeatureIndices = categoryMaps.keys.toArray.sorted val localVectorMap = categoryMaps - val f: Vector => Vector = { - case dv: DenseVector => - val tmpv = dv.copy - localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) => - tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) - } - tmpv - case sv: SparseVector => - // We use the fact that categorical value 0 is always mapped to index 0. - val tmpv = sv.copy - var catFeatureIdx = 0 // index into sortedCategoricalFeatureIndices - var k = 0 // index into non-zero elements of sparse vector - while (catFeatureIdx < sortedCategoricalFeatureIndices.length && k < tmpv.indices.length) { - val featureIndex = sortedCategoricalFeatureIndices(catFeatureIdx) - if (featureIndex < tmpv.indices(k)) { - catFeatureIdx += 1 - } else if (featureIndex > tmpv.indices(k)) { - k += 1 - } else { - tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k)) - catFeatureIdx += 1 - k += 1 + val localNumFeatures = numFeatures + val f: Vector => Vector = { (v: Vector) => + assert(v.size == localNumFeatures, "VectorIndexerModel expected vector of length" + + s" $numFeatures but found length ${v.size}") + v match { + case dv: DenseVector => + val tmpv = dv.copy + localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) => + tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) } - } - tmpv + tmpv + case sv: SparseVector => + // We use the fact that categorical value 0 is always mapped to index 0. + val tmpv = sv.copy + var catFeatureIdx = 0 // index into sortedCatFeatureIndices + var k = 0 // index into non-zero elements of sparse vector + while (catFeatureIdx < sortedCatFeatureIndices.length && k < tmpv.indices.length) { + val featureIndex = sortedCatFeatureIndices(catFeatureIdx) + if (featureIndex < tmpv.indices(k)) { + catFeatureIdx += 1 + } else if (featureIndex > tmpv.indices(k)) { + k += 1 + } else { + tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k)) + catFeatureIdx += 1 + k += 1 + } + } + tmpv + } } f } @@ -326,13 +331,6 @@ class VectorIndexerModel private[ml] ( val map = extractParamMap(paramMap) val newField = prepOutputField(dataset.schema, map) val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol))) - // For now, just check the first row of inputCol for vector length. - val firstRow = dataset.select(map(inputCol)).take(1) - if (firstRow.length != 0) { - val actualNumFeatures = firstRow(0).getAs[Vector](0).size - require(numFeatures == actualNumFeatures, "VectorIndexerModel expected vector of length" + - s" $numFeatures but found length $actualNumFeatures") - } dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata)) } @@ -345,6 +343,7 @@ class VectorIndexerModel private[ml] ( s"VectorIndexerModel requires output column parameter: $outputCol") SchemaUtils.checkColumnType(schema, map(inputCol), dataType) + // If the input metadata specifies numFeatures, compare with expected numFeatures. val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) { Some(origAttrGroup.attributes.get.length) @@ -364,7 +363,7 @@ class VectorIndexerModel private[ml] ( * Prepare the output column field, including per-feature metadata. * @param schema Input schema * @param map Parameter map (with this class' embedded parameter map folded in) - * @return Output column field + * @return Output column field. This field does not contain non-ML metadata. */ private def prepOutputField(schema: StructType, map: ParamMap): StructField = { val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) @@ -391,6 +390,6 @@ class VectorIndexerModel private[ml] ( partialFeatureAttributes } val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes) - newAttributeGroup.toStructField(schema(map(inputCol)).metadata) + newAttributeGroup.toStructField() } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala new file mode 100644 index 0000000000000..0163fa8bd8a5b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row} + +/** + * Params for [[Word2Vec]] and [[Word2VecModel]]. + */ +private[feature] trait Word2VecBase extends Params + with HasInputCol with HasOutputCol with HasMaxIter with HasStepSize with HasSeed { + + /** + * The dimension of the code that you want to transform from words. + */ + final val vectorSize = new IntParam( + this, "vectorSize", "the dimension of codes after transforming from words") + setDefault(vectorSize -> 100) + + /** @group getParam */ + def getVectorSize: Int = getOrDefault(vectorSize) + + /** + * Number of partitions for sentences of words. + */ + final val numPartitions = new IntParam( + this, "numPartitions", "number of partitions for sentences of words") + setDefault(numPartitions -> 1) + + /** @group getParam */ + def getNumPartitions: Int = getOrDefault(numPartitions) + + /** + * The minimum number of times a token must appear to be included in the word2vec model's + * vocabulary. + */ + final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " + + "appear to be included in the word2vec model's vocabulary") + setDefault(minCount -> 5) + + /** @group getParam */ + def getMinCount: Int = getOrDefault(minCount) + + setDefault(stepSize -> 0.025) + setDefault(maxIter -> 1) + setDefault(seed -> 42L) + + /** + * Validate and transform the input schema. + */ + protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) + SchemaUtils.checkColumnType(schema, map(inputCol), new ArrayType(StringType, true)) + SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT) + } +} + +/** + * :: AlphaComponent :: + * Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further + * natural language processing or machine learning process. + */ +@AlphaComponent +final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setVectorSize(value: Int): this.type = set(vectorSize, value) + + /** @group setParam */ + def setStepSize(value: Double): this.type = set(stepSize, value) + + /** @group setParam */ + def setNumPartitions(value: Int): this.type = set(numPartitions, value) + + /** @group setParam */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + + /** @group setParam */ + def setMinCount(value: Int): this.type = set(minCount, value) + + override def fit(dataset: DataFrame, paramMap: ParamMap): Word2VecModel = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val input = dataset.select(map(inputCol)).map { case Row(v: Seq[String]) => v } + val wordVectors = new feature.Word2Vec() + .setLearningRate(map(stepSize)) + .setMinCount(map(minCount)) + .setNumIterations(map(maxIter)) + .setNumPartitions(map(numPartitions)) + .setSeed(map(seed)) + .setVectorSize(map(vectorSize)) + .fit(input) + val model = new Word2VecModel(this, map, wordVectors) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +/** + * :: AlphaComponent :: + * Model fitted by [[Word2Vec]]. + */ +@AlphaComponent +class Word2VecModel private[ml] ( + override val parent: Word2Vec, + override val fittingParamMap: ParamMap, + wordVectors: feature.Word2VecModel) + extends Model[Word2VecModel] with Word2VecBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** + * Transform a sentence column to a vector column to represent the whole sentence. The transform + * is performed by averaging all word vectors it contains. + */ + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors) + val word2Vec = udf { sentence: Seq[String] => + if (sentence.size == 0) { + Vectors.sparse(map(vectorSize), Array.empty[Int], Array.empty[Double]) + } else { + val cum = Vectors.zeros(map(vectorSize)) + val model = bWordVectors.value.getVectors + for (word <- sentence) { + if (model.contains(word)) { + axpy(1.0, bWordVectors.value.transform(word), cum) + } else { + // pass words which not belong to model + } + } + scal(1.0 / sentence.size, cum) + cum + } + } + dataset.withColumn(map(outputCol), word2Vec(col(map(inputCol)))) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala index ab6281b9b2e34..fb770622e71f0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala @@ -38,14 +38,15 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} private[ml] trait DecisionTreeParams extends PredictorParams { /** - * Maximum depth of the tree. + * Maximum depth of the tree (>= 0). * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * (default = 5) * @group param */ final val maxDepth: IntParam = - new IntParam(this, "maxDepth", "Maximum depth of the tree." + - " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.") + new IntParam(this, "maxDepth", "Maximum depth of the tree. (>= 0)" + + " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.", + ParamValidators.gtEq(0)) /** * Maximum number of bins used for discretizing continuous features and for choosing how to split @@ -56,7 +57,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams { */ final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" + " discretizing continuous features. Must be >=2 and >= number of categories for any" + - " categorical feature.") + " categorical feature.", ParamValidators.gtEq(2)) /** * Minimum number of instances each child must have after split. @@ -69,7 +70,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams { final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" + " number of instances each child must have after split. If a split causes the left or right" + " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." + - " Should be >= 1.") + " Should be >= 1.", ParamValidators.gtEq(1)) /** * Minimum information gain for a split to be considered at a tree node. @@ -85,7 +86,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams { * @group expertParam */ final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB", - "Maximum memory in MB allocated to histogram aggregation.") + "Maximum memory in MB allocated to histogram aggregation.", + ParamValidators.gtEq(0)) /** * If false, the algorithm will pass trees to executors to match instances with nodes. @@ -111,34 +113,26 @@ private[ml] trait DecisionTreeParams extends PredictorParams { final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" + " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" + " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" + - " checkpoint directory is set in the SparkContext. Must be >= 1.") + " checkpoint directory is set in the SparkContext. Must be >= 1.", + ParamValidators.gtEq(1)) setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) /** @group setParam */ - def setMaxDepth(value: Int): this.type = { - require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value") - set(maxDepth, value) - } + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group getParam */ final def getMaxDepth: Int = getOrDefault(maxDepth) /** @group setParam */ - def setMaxBins(value: Int): this.type = { - require(value >= 2, s"maxBins parameter must be >= 2. Given bad value: $value") - set(maxBins, value) - } + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group getParam */ final def getMaxBins: Int = getOrDefault(maxBins) /** @group setParam */ - def setMinInstancesPerNode(value: Int): this.type = { - require(value >= 1, s"minInstancesPerNode parameter must be >= 1. Given bad value: $value") - set(minInstancesPerNode, value) - } + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group getParam */ final def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode) @@ -150,10 +144,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams { final def getMinInfoGain: Double = getOrDefault(minInfoGain) /** @group expertSetParam */ - def setMaxMemoryInMB(value: Int): this.type = { - require(value > 0, s"maxMemoryInMB parameter must be > 0. Given bad value: $value") - set(maxMemoryInMB, value) - } + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertGetParam */ final def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB) @@ -165,10 +156,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams { final def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds) /** @group expertSetParam */ - def setCheckpointInterval(value: Int): this.type = { - require(value >= 1, s"checkpointInterval parameter must be >= 1. Given bad value: $value") - set(checkpointInterval, value) - } + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group expertGetParam */ final def getCheckpointInterval: Int = getOrDefault(checkpointInterval) @@ -209,21 +197,16 @@ private[ml] trait TreeClassifierParams extends Params { */ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + - s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}") + s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}", + (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase)) setDefault(impurity -> "gini") /** @group setParam */ - def setImpurity(value: String): this.type = { - val impurityStr = value.toLowerCase - require(TreeClassifierParams.supportedImpurities.contains(impurityStr), - s"Tree-based classifier was given unrecognized impurity: $value." + - s" Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}") - set(impurity, impurityStr) - } + def setImpurity(value: String): this.type = set(impurity, value) /** @group getParam */ - final def getImpurity: String = getOrDefault(impurity) + final def getImpurity: String = getOrDefault(impurity).toLowerCase /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -256,21 +239,16 @@ private[ml] trait TreeRegressorParams extends Params { */ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + - s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}") + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", + (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase)) setDefault(impurity -> "variance") /** @group setParam */ - def setImpurity(value: String): this.type = { - val impurityStr = value.toLowerCase - require(TreeRegressorParams.supportedImpurities.contains(impurityStr), - s"Tree-based regressor was given unrecognized impurity: $value." + - s" Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}") - set(impurity, impurityStr) - } + def setImpurity(value: String): this.type = set(impurity, value) /** @group getParam */ - final def getImpurity: String = getOrDefault(impurity) + final def getImpurity: String = getOrDefault(impurity).toLowerCase /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -299,21 +277,18 @@ private[ml] object TreeRegressorParams { private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { /** - * Fraction of the training data used for learning each decision tree. + * Fraction of the training data used for learning each decision tree, in range (0, 1]. * (default = 1.0) * @group param */ final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate", - "Fraction of the training data used for learning each decision tree.") + "Fraction of the training data used for learning each decision tree, in range (0, 1].", + ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) setDefault(subsamplingRate -> 1.0) /** @group setParam */ - def setSubsamplingRate(value: Double): this.type = { - require(value > 0.0 && value <= 1.0, - s"Subsampling rate must be in range (0,1]. Bad rate: $value") - set(subsamplingRate, value) - } + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group getParam */ final def getSubsamplingRate: Double = getOrDefault(subsamplingRate) @@ -350,7 +325,8 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { * (default = 20) * @group param */ - final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)") + final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", + ParamValidators.gtEq(1)) /** * The number of features to consider for splits at each tree node. @@ -378,30 +354,23 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { */ final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy", "The number of features to consider for splits at each tree node." + - s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}") + s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}", + (value: String) => + RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)) setDefault(numTrees -> 20, featureSubsetStrategy -> "auto") /** @group setParam */ - def setNumTrees(value: Int): this.type = { - require(value >= 1, s"Random Forest numTrees parameter cannot be $value; it must be >= 1.") - set(numTrees, value) - } + def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group getParam */ final def getNumTrees: Int = getOrDefault(numTrees) /** @group setParam */ - def setFeatureSubsetStrategy(value: String): this.type = { - val strategyStr = value.toLowerCase - require(RandomForestParams.supportedFeatureSubsetStrategies.contains(strategyStr), - s"RandomForestParams was given unrecognized featureSubsetStrategy: $value. Supported" + - s" options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}") - set(featureSubsetStrategy, strategyStr) - } + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) /** @group getParam */ - final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy) + final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy).toLowerCase } private[ml] object RandomForestParams { @@ -426,7 +395,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { * @group param */ final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." + - " learning rate) in interval (0, 1] for shrinking the contribution of each estimator") + " learning rate) in interval (0, 1] for shrinking the contribution of each estimator", + ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) /* TODO: Add this doc when we add this param. SPARK-7132 * Threshold for stopping early when runWithValidation is used. @@ -442,17 +412,10 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { setDefault(maxIter -> 20, stepSize -> 0.1) /** @group setParam */ - def setMaxIter(value: Int): this.type = { - require(value >= 1, s"Gradient Boosting maxIter parameter cannot be $value; it must be >= 1.") - set(maxIter, value) - } + def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ - def setStepSize(value: Double): this.type = { - require(value > 0.0 && value <= 1.0, - s"GBT given invalid step size ($value). Value should be in (0,1].") - set(stepSize, value) - } + def setStepSize(value: Double): this.type = set(stepSize, value) /** @group getParam */ final def getStepSize: Double = getOrDefault(stepSize) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ddc5907e7facd..df6360dce6013 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,7 +24,7 @@ import scala.annotation.varargs import scala.collection.mutable import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.Identifiable +import org.apache.spark.ml.util.Identifiable /** * :: AlphaComponent :: @@ -34,10 +34,35 @@ import org.apache.spark.ml.Identifiable * @param parent parent object * @param name param name * @param doc documentation + * @param isValid optional validation method which indicates if a value is valid. + * See [[ParamValidators]] for factory methods for common validation functions. * @tparam T param value type */ @AlphaComponent -class Param[T] (val parent: Params, val name: String, val doc: String) extends Serializable { +class Param[T] (val parent: Params, val name: String, val doc: String, val isValid: T => Boolean) + extends Serializable { + + def this(parent: Params, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue[T]) + + /** + * Assert that the given value is valid for this parameter. + * + * Note: Parameter checks involving interactions between multiple parameters should be + * implemented in [[Params.validate()]]. Checks for input/output columns should be + * implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]]. + * + * DEVELOPERS: This method is only called by [[ParamPair]], which means that all parameters + * should be specified via [[ParamPair]]. + * + * @throws IllegalArgumentException if the value is invalid + */ + private[param] def validate(value: T): Unit = { + if (!isValid(value)) { + throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value." + + s" Parameter description: $toString") + } + } /** * Creates a param pair with the given value (for Java). @@ -65,38 +90,129 @@ class Param[T] (val parent: Params, val name: String, val doc: String) extends S } } +/** + * Factory methods for common validation functions for [[Param.isValid]]. + * The numerical methods only support Int, Long, Float, and Double. + */ +object ParamValidators { + + /** (private[param]) Default validation always return true */ + private[param] def alwaysTrue[T]: T => Boolean = (_: T) => true + + /** + * Private method for checking numerical types and converting to Double. + * This is mainly for the sake of compilation; type checks are really handled + * by [[Params]] setters and the [[ParamPair]] constructor. + */ + private def getDouble[T](value: T): Double = value match { + case x: Int => x.toDouble + case x: Long => x.toDouble + case x: Float => x.toDouble + case x: Double => x.toDouble + case _ => + // The type should be checked before this is ever called. + throw new IllegalArgumentException("Numerical Param validation failed because" + + s" of unexpected input type: ${value.getClass}") + } + + /** Check if value > lowerBound */ + def gt[T](lowerBound: Double): T => Boolean = { (value: T) => + getDouble(value) > lowerBound + } + + /** Check if value >= lowerBound */ + def gtEq[T](lowerBound: Double): T => Boolean = { (value: T) => + getDouble(value) >= lowerBound + } + + /** Check if value < upperBound */ + def lt[T](upperBound: Double): T => Boolean = { (value: T) => + getDouble(value) < upperBound + } + + /** Check if value <= upperBound */ + def ltEq[T](upperBound: Double): T => Boolean = { (value: T) => + getDouble(value) <= upperBound + } + + /** + * Check for value in range lowerBound to upperBound. + * @param lowerInclusive If true, check for value >= lowerBound. + * If false, check for value > lowerBound. + * @param upperInclusive If true, check for value <= upperBound. + * If false, check for value < upperBound. + */ + def inRange[T]( + lowerBound: Double, + upperBound: Double, + lowerInclusive: Boolean, + upperInclusive: Boolean): T => Boolean = { (value: T) => + val x: Double = getDouble(value) + val lowerValid = if (lowerInclusive) x >= lowerBound else x > lowerBound + val upperValid = if (upperInclusive) x <= upperBound else x < upperBound + lowerValid && upperValid + } + + /** Version of [[inRange()]] which uses inclusive be default: [lowerBound, upperBound] */ + def inRange[T](lowerBound: Double, upperBound: Double): T => Boolean = { + inRange[T](lowerBound, upperBound, lowerInclusive = true, upperInclusive = true) + } + + /** Check for value in an allowed set of values. */ + def inArray[T](allowed: Array[T]): T => Boolean = { (value: T) => + allowed.contains(value) + } + + /** Check for value in an allowed set of values. */ + def inArray[T](allowed: java.util.List[T]): T => Boolean = { (value: T) => + allowed.contains(value) + } +} + // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... /** Specialized version of [[Param[Double]]] for Java. */ -class DoubleParam(parent: Params, name: String, doc: String) - extends Param[Double](parent, name, doc) { +class DoubleParam(parent: Params, name: String, doc: String, isValid: Double => Boolean) + extends Param[Double](parent, name, doc, isValid) { + + def this(parent: Params, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) override def w(value: Double): ParamPair[Double] = super.w(value) } /** Specialized version of [[Param[Int]]] for Java. */ -class IntParam(parent: Params, name: String, doc: String) - extends Param[Int](parent, name, doc) { +class IntParam(parent: Params, name: String, doc: String, isValid: Int => Boolean) + extends Param[Int](parent, name, doc, isValid) { + + def this(parent: Params, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) override def w(value: Int): ParamPair[Int] = super.w(value) } /** Specialized version of [[Param[Float]]] for Java. */ -class FloatParam(parent: Params, name: String, doc: String) - extends Param[Float](parent, name, doc) { +class FloatParam(parent: Params, name: String, doc: String, isValid: Float => Boolean) + extends Param[Float](parent, name, doc, isValid) { + + def this(parent: Params, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) override def w(value: Float): ParamPair[Float] = super.w(value) } /** Specialized version of [[Param[Long]]] for Java. */ -class LongParam(parent: Params, name: String, doc: String) - extends Param[Long](parent, name, doc) { +class LongParam(parent: Params, name: String, doc: String, isValid: Long => Boolean) + extends Param[Long](parent, name, doc, isValid) { + + def this(parent: Params, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) override def w(value: Long): ParamPair[Long] = super.w(value) } /** Specialized version of [[Param[Boolean]]] for Java. */ -class BooleanParam(parent: Params, name: String, doc: String) +class BooleanParam(parent: Params, name: String, doc: String) // No need for isValid extends Param[Boolean](parent, name, doc) { override def w(value: Boolean): ParamPair[Boolean] = super.w(value) @@ -105,7 +221,11 @@ class BooleanParam(parent: Params, name: String, doc: String) /** * A param amd its value. */ -case class ParamPair[T](param: Param[T], value: T) +case class ParamPair[T](param: Param[T], value: T) { + // This is *the* place Param.validate is called. Whenever a parameter is specified, we should + // always construct a ParamPair so that validate is called. + param.validate(value) +} /** * :: AlphaComponent :: @@ -132,12 +252,22 @@ trait Params extends Identifiable with Serializable { /** * Validates parameter values stored internally plus the input parameter map. * Raises an exception if any parameter is invalid. + * + * This only needs to check for interactions between parameters. + * Parameter value checks which do not depend on other parameters are handled by + * [[Param.validate()]]. This method does not handle input/output column parameters; + * those are checked during schema validation. */ - def validate(paramMap: ParamMap): Unit = {} + def validate(paramMap: ParamMap): Unit = { } /** * Validates parameter values stored internally. * Raise an exception if any parameter value is invalid. + * + * This only needs to check for interactions between parameters. + * Parameter value checks which do not depend on other parameters are handled by + * [[Param.validate()]]. This method does not handle input/output column parameters; + * those are checked during schema validation. */ def validate(): Unit = validate(ParamMap.empty) @@ -221,6 +351,10 @@ trait Params extends Identifiable with Serializable { /** * Sets default values for a list of params. + * + * Note: Java developers should use the single-parameter [[setDefault()]]. + * Annotating this with varargs causes compilation failures. + * * @param paramPairs a list of param pairs that specify params and their default values to set * respectively. Make sure that the params are initialized before this method * gets called. @@ -305,6 +439,14 @@ private[spark] object Params { } } +/** + * Java-friendly wrapper for [[Params]]. + * Java developers who need to extend [[Params]] should use this class instead. + * If you need to extend a abstract class which already extends [[Params]], then that abstract + * class should be Java-friendly as well. + */ +abstract class JavaParams extends Params + /** * :: AlphaComponent :: * A param to value map. @@ -313,6 +455,12 @@ private[spark] object Params { final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { + /* DEVELOPERS: About validating parameter values + * This and ParamPair are the only two collections of parameters. + * This class should always create ParamPairs when + * specifying new parameter values. ParamPair will then call Param.validate(). + */ + /** * Creates an empty param map. */ @@ -321,10 +469,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Puts a (param, value) pair (overwrites if the input param exists). */ - def put[T](param: Param[T], value: T): this.type = { - map(param.asInstanceOf[Param[Any]]) = value - this - } + def put[T](param: Param[T], value: T): this.type = put(ParamPair(param, value)) /** * Puts a list of param pairs (overwrites if the input params exists). @@ -332,7 +477,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) @varargs def put(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => - put(p.param.asInstanceOf[Param[Any]], p.value) + map(p.param.asInstanceOf[Param[Any]]) = p.value } this } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index e88c48741e99f..7da4bb4b4bf25 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -21,6 +21,8 @@ import java.io.PrintWriter import scala.reflect.ClassTag +import org.apache.spark.ml.param.ParamValidators + /** * Code generator for shared params (sharedParams.scala). Run under the Spark folder with * {{{ @@ -31,8 +33,10 @@ private[shared] object SharedParamsCodeGen { def main(args: Array[String]): Unit = { val params = Seq( - ParamDesc[Double]("regParam", "regularization parameter"), - ParamDesc[Int]("maxIter", "max number of iterations"), + ParamDesc[Double]("regParam", "regularization parameter (>= 0)", + isValid = "ParamValidators.gtEq(0)"), + ParamDesc[Int]("maxIter", "max number of iterations (>= 0)", + isValid = "ParamValidators.gtEq(0)"), ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")), ParamDesc[String]("labelCol", "label column name", Some("\"label\"")), ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")), @@ -40,13 +44,21 @@ private[shared] object SharedParamsCodeGen { Some("\"rawPrediction\"")), ParamDesc[String]("probabilityCol", "column name for predicted class conditional probabilities", Some("\"probability\"")), - ParamDesc[Double]("threshold", "threshold in binary classification prediction"), + ParamDesc[Double]("threshold", + "threshold in binary classification prediction, in range [0, 1]", + isValid = "ParamValidators.inRange(0, 1)"), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name"), - ParamDesc[Int]("checkpointInterval", "checkpoint interval"), + ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", + isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), - ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()"))) + ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")), + ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." + + " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", + isValid = "ParamValidators.inRange(0, 1)"), + ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"), + ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization.")) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" @@ -59,7 +71,8 @@ private[shared] object SharedParamsCodeGen { private case class ParamDesc[T: ClassTag]( name: String, doc: String, - defaultValueStr: Option[String] = None) { + defaultValueStr: Option[String] = None, + isValid: String = "") { require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") require(doc.nonEmpty) // TODO: more rigorous on doc @@ -110,20 +123,23 @@ private[shared] object SharedParamsCodeGen { | setDefault($name, $v) |""".stripMargin }.getOrElse("") + val isValid = if (param.isValid != "") { + ", " + param.isValid + } else { + "" + } s""" |/** - | * :: DeveloperApi :: - | * Trait for shared param $name$defaultValueDoc. + | * (private[ml]) Trait for shared param $name$defaultValueDoc. | */ - |@DeveloperApi - |trait Has$Name extends Params { + |private[ml] trait Has$Name extends Params { | | /** | * Param for $doc. | * @group param | */ - | final val $name: $Param = new $Param(this, "$name", "$doc") + | final val $name: $Param = new $Param(this, "$name", "$doc"$isValid) |$setDefault | /** @group getParam */ | final def get$Name: $T = getOrDefault($name) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index a860b8834cff9..e1549f46a68d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -26,45 +26,39 @@ import org.apache.spark.util.Utils // scalastyle:off /** - * :: DeveloperApi :: - * Trait for shared param regParam. + * (private[ml]) Trait for shared param regParam. */ -@DeveloperApi -trait HasRegParam extends Params { +private[ml] trait HasRegParam extends Params { /** - * Param for regularization parameter. + * Param for regularization parameter (>= 0). * @group param */ - final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") + final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter (>= 0)", ParamValidators.gtEq(0)) /** @group getParam */ final def getRegParam: Double = getOrDefault(regParam) } /** - * :: DeveloperApi :: - * Trait for shared param maxIter. + * (private[ml]) Trait for shared param maxIter. */ -@DeveloperApi -trait HasMaxIter extends Params { +private[ml] trait HasMaxIter extends Params { /** - * Param for max number of iterations. + * Param for max number of iterations (>= 0). * @group param */ - final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0)) /** @group getParam */ final def getMaxIter: Int = getOrDefault(maxIter) } /** - * :: DeveloperApi :: - * Trait for shared param featuresCol (default: "features"). + * (private[ml]) Trait for shared param featuresCol (default: "features"). */ -@DeveloperApi -trait HasFeaturesCol extends Params { +private[ml] trait HasFeaturesCol extends Params { /** * Param for features column name. @@ -79,11 +73,9 @@ trait HasFeaturesCol extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param labelCol (default: "label"). + * (private[ml]) Trait for shared param labelCol (default: "label"). */ -@DeveloperApi -trait HasLabelCol extends Params { +private[ml] trait HasLabelCol extends Params { /** * Param for label column name. @@ -98,11 +90,9 @@ trait HasLabelCol extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param predictionCol (default: "prediction"). + * (private[ml]) Trait for shared param predictionCol (default: "prediction"). */ -@DeveloperApi -trait HasPredictionCol extends Params { +private[ml] trait HasPredictionCol extends Params { /** * Param for prediction column name. @@ -117,11 +107,9 @@ trait HasPredictionCol extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param rawPredictionCol (default: "rawPrediction"). + * (private[ml]) Trait for shared param rawPredictionCol (default: "rawPrediction"). */ -@DeveloperApi -trait HasRawPredictionCol extends Params { +private[ml] trait HasRawPredictionCol extends Params { /** * Param for raw prediction (a.k.a. confidence) column name. @@ -136,11 +124,9 @@ trait HasRawPredictionCol extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param probabilityCol (default: "probability"). + * (private[ml]) Trait for shared param probabilityCol (default: "probability"). */ -@DeveloperApi -trait HasProbabilityCol extends Params { +private[ml] trait HasProbabilityCol extends Params { /** * Param for column name for predicted class conditional probabilities. @@ -155,28 +141,24 @@ trait HasProbabilityCol extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param threshold. + * (private[ml]) Trait for shared param threshold. */ -@DeveloperApi -trait HasThreshold extends Params { +private[ml] trait HasThreshold extends Params { /** - * Param for threshold in binary classification prediction. + * Param for threshold in binary classification prediction, in range [0, 1]. * @group param */ - final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction") + final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1)) /** @group getParam */ final def getThreshold: Double = getOrDefault(threshold) } /** - * :: DeveloperApi :: - * Trait for shared param inputCol. + * (private[ml]) Trait for shared param inputCol. */ -@DeveloperApi -trait HasInputCol extends Params { +private[ml] trait HasInputCol extends Params { /** * Param for input column name. @@ -189,11 +171,9 @@ trait HasInputCol extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param inputCols. + * (private[ml]) Trait for shared param inputCols. */ -@DeveloperApi -trait HasInputCols extends Params { +private[ml] trait HasInputCols extends Params { /** * Param for input column names. @@ -206,11 +186,9 @@ trait HasInputCols extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param outputCol. + * (private[ml]) Trait for shared param outputCol. */ -@DeveloperApi -trait HasOutputCol extends Params { +private[ml] trait HasOutputCol extends Params { /** * Param for output column name. @@ -223,28 +201,24 @@ trait HasOutputCol extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param checkpointInterval. + * (private[ml]) Trait for shared param checkpointInterval. */ -@DeveloperApi -trait HasCheckpointInterval extends Params { +private[ml] trait HasCheckpointInterval extends Params { /** - * Param for checkpoint interval. + * Param for checkpoint interval (>= 1). * @group param */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval") + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)", ParamValidators.gtEq(1)) /** @group getParam */ final def getCheckpointInterval: Int = getOrDefault(checkpointInterval) } /** - * :: DeveloperApi :: - * Trait for shared param fitIntercept (default: true). + * (private[ml]) Trait for shared param fitIntercept (default: true). */ -@DeveloperApi -trait HasFitIntercept extends Params { +private[ml] trait HasFitIntercept extends Params { /** * Param for whether to fit an intercept term. @@ -259,11 +233,9 @@ trait HasFitIntercept extends Params { } /** - * :: DeveloperApi :: - * Trait for shared param seed (default: Utils.random.nextLong()). + * (private[ml]) Trait for shared param seed (default: Utils.random.nextLong()). */ -@DeveloperApi -trait HasSeed extends Params { +private[ml] trait HasSeed extends Params { /** * Param for random seed. @@ -276,4 +248,49 @@ trait HasSeed extends Params { /** @group getParam */ final def getSeed: Long = getOrDefault(seed) } + +/** + * (private[ml]) Trait for shared param elasticNetParam. + */ +private[ml] trait HasElasticNetParam extends Params { + + /** + * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.. + * @group param + */ + final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidators.inRange(0, 1)) + + /** @group getParam */ + final def getElasticNetParam: Double = getOrDefault(elasticNetParam) +} + +/** + * (private[ml]) Trait for shared param tol. + */ +private[ml] trait HasTol extends Params { + + /** + * Param for the convergence tolerance for iterative algorithms. + * @group param + */ + final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms") + + /** @group getParam */ + final def getTol: Double = getOrDefault(tol) +} + +/** + * (private[ml]) Trait for shared param stepSize. + */ +private[ml] trait HasStepSize extends Params { + + /** + * Param for Step size to be used for each iteration of optimization.. + * @group param + */ + final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization.") + + /** @group getParam */ + final def getStepSize: Double = getOrDefault(stepSize) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index bd793beba35b6..f9f2b2764ddb1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -52,35 +52,40 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR with HasPredictionCol with HasCheckpointInterval { /** - * Param for rank of the matrix factorization. + * Param for rank of the matrix factorization (>= 1). + * Default: 10 * @group param */ - val rank = new IntParam(this, "rank", "rank of the factorization") + val rank = new IntParam(this, "rank", "rank of the factorization", ParamValidators.gtEq(1)) /** @group getParam */ def getRank: Int = getOrDefault(rank) /** - * Param for number of user blocks. + * Param for number of user blocks (>= 1). + * Default: 10 * @group param */ - val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks") + val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", + ParamValidators.gtEq(1)) /** @group getParam */ def getNumUserBlocks: Int = getOrDefault(numUserBlocks) /** - * Param for number of item blocks. + * Param for number of item blocks (>= 1). + * Default: 10 * @group param */ - val numItemBlocks = - new IntParam(this, "numItemBlocks", "number of item blocks") + val numItemBlocks = new IntParam(this, "numItemBlocks", "number of item blocks", + ParamValidators.gtEq(1)) /** @group getParam */ def getNumItemBlocks: Int = getOrDefault(numItemBlocks) /** * Param to decide whether to use implicit preference. + * Default: false * @group param */ val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference") @@ -89,16 +94,19 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR def getImplicitPrefs: Boolean = getOrDefault(implicitPrefs) /** - * Param for the alpha parameter in the implicit preference formulation. + * Param for the alpha parameter in the implicit preference formulation (>= 0). + * Default: 1.0 * @group param */ - val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference") + val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", + ParamValidators.gtEq(0)) /** @group getParam */ def getAlpha: Double = getOrDefault(alpha) /** * Param for the column name for user ids. + * Default: "user" * @group param */ val userCol = new Param[String](this, "userCol", "column name for user ids") @@ -108,6 +116,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR /** * Param for the column name for item ids. + * Default: "item" * @group param */ val itemCol = new Param[String](this, "itemCol", "column name for item ids") @@ -117,6 +126,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR /** * Param for the column name for ratings. + * Default: "rating" * @group param */ val ratingCol = new Param[String](this, "ratingCol", "column name for ratings") @@ -126,6 +136,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR /** * Param for whether to apply nonnegativity constraints. + * Default: false * @group param */ val nonnegative = new BooleanParam( @@ -136,7 +147,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", - ratingCol -> "rating", nonnegative -> false) + ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10) /** * Validates and transforms the input schema. @@ -281,10 +292,6 @@ class ALS extends Estimator[ALSModel] with ALSParams { this } - setMaxIter(20) - setRegParam(1.0) - setCheckpointInterval(10) - override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { val map = extractParamMap(paramMap) val ratings = dataset diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index c784cf39ed31a..76c98376930c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -102,21 +102,16 @@ final class GBTRegressor */ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + " tries to minimize (case-insensitive). Supported options:" + - s" ${GBTRegressor.supportedLossTypes.mkString(", ")}") + s" ${GBTRegressor.supportedLossTypes.mkString(", ")}", + (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase)) setDefault(lossType -> "squared") /** @group setParam */ - def setLossType(value: String): this.type = { - val lossStr = value.toLowerCase - require(GBTRegressor.supportedLossTypes.contains(lossStr), "GBTRegressor was given bad loss" + - s" type: $value. Supported options: ${GBTRegressor.supportedLossTypes.mkString(", ")}") - set(lossType, lossStr) - this - } + def setLossType(value: String): this.type = set(lossType, value) /** @group getParam */ - def getLossType: String = getOrDefault(lossType) + def getLossType: String = getOrDefault(lossType).toLowerCase /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldLoss = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 26ca7459c4fdf..0b81c48466be9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -17,59 +17,168 @@ package org.apache.spark.ml.regression +import scala.collection.mutable + +import breeze.linalg.{norm => brzNorm, DenseVector => BDV} +import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, DiffFunction} + import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{Params, ParamMap} -import org.apache.spark.ml.param.shared._ -import org.apache.spark.mllib.linalg.{BLAS, Vector} -import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.ml.param.shared.{HasTol, HasElasticNetParam, HasMaxIter, HasRegParam} +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel - +import org.apache.spark.util.StatCounter +import org.apache.spark.Logging /** * Params for linear regression. */ private[regression] trait LinearRegressionParams extends RegressorParams - with HasRegParam with HasMaxIter - + with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol /** * :: AlphaComponent :: * * Linear regression. + * + * The learning objective is to minimize the squared error, with regularization. + * The specific squared error loss function used is: + * L = 1/2n ||A weights - y||^2^ + * + * This support multiple types of regularization: + * - none (a.k.a. ordinary least squares) + * - L2 (ridge regression) + * - L1 (Lasso) + * - L2 + L1 (elastic net) */ @AlphaComponent class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams { + with LinearRegressionParams with Logging { - setDefault(regParam -> 0.1, maxIter -> 100) - - /** @group setParam */ + /** + * Set the regularization parameter. + * Default is 0.0. + * @group setParam + */ def setRegParam(value: Double): this.type = set(regParam, value) + setDefault(regParam -> 0.0) + + /** + * Set the ElasticNet mixing parameter. + * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. + * For 0 < alpha < 1, the penalty is a combination of L1 and L2. + * Default is 0.0 which is an L2 penalty. + * @group setParam + */ + def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value) + setDefault(elasticNetParam -> 0.0) - /** @group setParam */ + /** + * Set the maximal number of iterations. + * Default is 100. + * @group setParam + */ def setMaxIter(value: Int): this.type = set(maxIter, value) + setDefault(maxIter -> 100) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = { - // Extract columns from data. If dataset is persisted, do not persist oldDataset. - val oldDataset = extractLabeledPoints(dataset, paramMap) + // Extract columns from data. If dataset is persisted, do not persist instances. + val instances = extractLabeledPoints(dataset, paramMap).map { + case LabeledPoint(label: Double, features: Vector) => (label, features) + } val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE - if (handlePersistence) { - oldDataset.persist(StorageLevel.MEMORY_AND_DISK) + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + + val (summarizer, statCounter) = instances.treeAggregate( + (new MultivariateOnlineSummarizer, new StatCounter))( { + case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter), + (label: Double, features: Vector)) => + (summarizer.add(features), statCounter.merge(label)) + }, { + case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter), + (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) => + (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2)) + }) + + val numFeatures = summarizer.mean.size + val yMean = statCounter.mean + val yStd = math.sqrt(statCounter.variance) + + // If the yStd is zero, then the intercept is yMean with zero weights; + // as a result, training is not needed. + if (yStd == 0.0) { + logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " + + s"and the intercept will be the mean of the label; as a result, training is not needed.") + if (handlePersistence) instances.unpersist() + return new LinearRegressionModel(this, paramMap, Vectors.sparse(numFeatures, Seq()), yMean) } - // Train model - val lr = new LinearRegressionWithSGD() - lr.optimizer - .setRegParam(paramMap(regParam)) - .setNumIterations(paramMap(maxIter)) - val model = lr.run(oldDataset) - val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept) + val featuresMean = summarizer.mean.toArray + val featuresStd = summarizer.variance.toArray.map(math.sqrt) - if (handlePersistence) { - oldDataset.unpersist() + // Since we implicitly do the feature scaling when we compute the cost function + // to improve the convergence, the effective regParam will be changed. + val effectiveRegParam = paramMap(regParam) / yStd + val effectiveL1RegParam = paramMap(elasticNetParam) * effectiveRegParam + val effectiveL2RegParam = (1.0 - paramMap(elasticNetParam)) * effectiveRegParam + + val costFun = new LeastSquaresCostFun(instances, yStd, yMean, + featuresStd, featuresMean, effectiveL2RegParam) + + val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { + new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol)) + } else { + new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol)) } - lrm + + val initialWeights = Vectors.zeros(numFeatures) + val states = + optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector) + + var state = states.next() + val lossHistory = mutable.ArrayBuilder.make[Double] + + while (states.hasNext) { + lossHistory += state.value + state = states.next() + } + lossHistory += state.value + + // The weights are trained in the scaled space; we're converting them back to + // the original space. + val weights = { + val rawWeights = state.x.toArray.clone() + var i = 0 + while (i < rawWeights.length) { + rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 } + i += 1 + } + Vectors.dense(rawWeights) + } + + // The intercept in R's GLMNET is computed using closed form after the coefficients are + // converged. See the following discussion for detail. + // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet + val intercept = yMean - dot(weights, Vectors.dense(featuresMean)) + if (handlePersistence) instances.unpersist() + + // TODO: Converts to sparse format based on the storage, but may base on the scoring speed. + new LinearRegressionModel(this, paramMap, weights.compressed, intercept) } } @@ -88,7 +197,7 @@ class LinearRegressionModel private[ml] ( with LinearRegressionParams { override protected def predict(features: Vector): Double = { - BLAS.dot(features, weights) + intercept + dot(features, weights) + intercept } override protected def copy(): LinearRegressionModel = { @@ -97,3 +206,223 @@ class LinearRegressionModel private[ml] ( m } } + +/** + * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function, + * as used in linear regression for samples in sparse or dense vector in a online fashion. + * + * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + * For improving the convergence rate during the optimization process, and also preventing against + * features with very large variances exerting an overly large influence during model training, + * package like R's GLMNET performs the scaling to unit variance and removing the mean to reduce + * the condition number, and then trains the model in scaled space but returns the weights in + * the original scale. See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf + * + * However, we don't want to apply the `StandardScaler` on the training dataset, and then cache + * the standardized dataset since it will create a lot of overhead. As a result, we perform the + * scaling implicitly when we compute the objective function. The following is the mathematical + * derivation. + * + * Note that we don't deal with intercept by adding bias here, because the intercept + * can be computed using closed form after the coefficients are converged. + * See this discussion for detail. + * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet + * + * The objective function in the scaled space is given by + * {{{ + * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2, + * }}} + * where \bar{x_i} is the mean of x_i, \hat{x_i} is the standard deviation of x_i, + * \bar{y} is the mean of label, and \hat{y} is the standard deviation of label. + * + * This can be rewritten as + * {{{ + * L = 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y} + * + \bar{y} / \hat{y}||^2 + * = 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2 + * }}} + * where w_i^\prime is the effective weights defined by w_i/\hat{x_i}, offset is + * {{{ + * - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}. + * }}}, and diff is + * {{{ + * \sum_i w_i^\prime x_i - y / \hat{y} + offset + * }}} + * + * Note that the effective weights and offset don't depend on training dataset, + * so they can be precomputed. + * + * Now, the first derivative of the objective function in scaled space is + * {{{ + * \frac{\partial L}{\partial\w_i} = diff/N (x_i - \bar{x_i}) / \hat{x_i} + * }}} + * However, ($x_i - \bar{x_i}$) will densify the computation, so it's not + * an ideal formula when the training dataset is sparse format. + * + * This can be addressed by adding the dense \bar{x_i} / \har{x_i} terms + * in the end by keeping the sum of diff. The first derivative of total + * objective function from all the samples is + * {{{ + * \frac{\partial L}{\partial\w_i} = + * 1/N \sum_j diff_j (x_{ij} - \bar{x_i}) / \hat{x_i} + * = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - diffSum \bar{x_i}) / \hat{x_i}) + * = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + correction_i) + * }}}, + * where correction_i = - diffSum \bar{x_i}) / \hat{x_i} + * + * A simple math can show that diffSum is actually zero, so we don't even + * need to add the correction terms in the end. From the definition of diff, + * {{{ + * diffSum = \sum_j (\sum_i w_i(x_{ij} - \bar{x_i}) / \hat{x_i} - (y_j - \bar{y}) / \hat{y}) + * = N * (\sum_i w_i(\bar{x_i} - \bar{x_i}) / \hat{x_i} - (\bar{y_j} - \bar{y}) / \hat{y}) + * = 0 + * }}} + * + * As a result, the first derivative of the total objective function only depends on + * the training dataset, which can be easily computed in distributed fashion, and is + * sparse format friendly. + * {{{ + * \frac{\partial L}{\partial\w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + * }}}, + * + * @param weights The weights/coefficients corresponding to the features. + * @param labelStd The standard deviation value of the label. + * @param labelMean The mean value of the label. + * @param featuresStd The standard deviation values of the features. + * @param featuresMean The mean values of the features. + */ +private class LeastSquaresAggregator( + weights: Vector, + labelStd: Double, + labelMean: Double, + featuresStd: Array[Double], + featuresMean: Array[Double]) extends Serializable { + + private var totalCnt: Long = 0L + private var lossSum = 0.0 + + private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = { + val weightsArray = weights.toArray.clone() + var sum = 0.0 + var i = 0 + while (i < weightsArray.length) { + if (featuresStd(i) != 0.0) { + weightsArray(i) /= featuresStd(i) + sum += weightsArray(i) * featuresMean(i) + } else { + weightsArray(i) = 0.0 + } + i += 1 + } + (weightsArray, -sum + labelMean / labelStd, weightsArray.length) + } + + private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) + + private val gradientSumArray = Array.ofDim[Double](dim) + + /** + * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient + * of the objective function. + * + * @param label The label for this data point. + * @param data The features for one data point in dense/sparse vector format to be added + * into this aggregator. + * @return This LeastSquaresAggregator object. + */ + def add(label: Double, data: Vector): this.type = { + require(dim == data.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $dim but got ${data.size}.") + + val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset + + if (diff != 0) { + val localGradientSumArray = gradientSumArray + data.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += diff * value / featuresStd(index) + } + } + lossSum += diff * diff / 2.0 + } + + totalCnt += 1 + this + } + + /** + * Merge another LeastSquaresAggregator, and update the loss and gradient + * of the objective function. + * (Note that it's in place merging; as a result, `this` object will be modified.) + * + * @param other The other LeastSquaresAggregator to be merged. + * @return This LeastSquaresAggregator object. + */ + def merge(other: LeastSquaresAggregator): this.type = { + require(dim == other.dim, s"Dimensions mismatch when merging with another " + + s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") + + if (other.totalCnt != 0) { + totalCnt += other.totalCnt + lossSum += other.lossSum + + var i = 0 + val localThisGradientSumArray = this.gradientSumArray + val localOtherGradientSumArray = other.gradientSumArray + while (i < dim) { + localThisGradientSumArray(i) += localOtherGradientSumArray(i) + i += 1 + } + } + this + } + + def count: Long = totalCnt + + def loss: Double = lossSum / totalCnt + + def gradient: Vector = { + val result = Vectors.dense(gradientSumArray.clone()) + scal(1.0 / totalCnt, result) + result + } +} + +/** + * LeastSquaresCostFun implements Breeze's DiffFunction[T] for Least Squares cost. + * It returns the loss and gradient with L2 regularization at a particular point (weights). + * It's used in Breeze's convex optimization routines. + */ +private class LeastSquaresCostFun( + data: RDD[(Double, Vector)], + labelStd: Double, + labelMean: Double, + featuresStd: Array[Double], + featuresMean: Array[Double], + effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] { + + override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { + val w = Vectors.fromBreeze(weights) + + val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd, + labelMean, featuresStd, featuresMean))( + seqOp = (c, v) => (c, v) match { + case (aggregator, (label, features)) => aggregator.add(label, features) + }, + combOp = (c1, c2) => (c1, c2) match { + case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) + }) + + // regVal is the sum of weight squares for L2 regularization + val norm = brzNorm(weights, 2.0) + val regVal = 0.5 * effectiveL2regParam * norm * norm + + val loss = leastSquaresAggregator.loss + regVal + val gradient = leastSquaresAggregator.gradient + axpy(effectiveL2regParam, w, gradient) + + (loss, gradient.toBreeze.asInstanceOf[BDV[Double]]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 4bb4ed813c006..d1ad0893cd044 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -22,7 +22,7 @@ import com.github.fommil.netlib.F2jBLAS import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ -import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -61,10 +61,12 @@ private[ml] trait CrossValidatorParams extends Params { def getEvaluator: Evaluator = getOrDefault(evaluator) /** - * param for number of folds for cross validation + * Param for number of folds for cross validation. Must be >= 2. + * Default: 3 * @group param */ - val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation") + val numFolds: IntParam = new IntParam(this, "numFolds", + "number of folds for cross validation (>= 2)", ParamValidators.gtEq(2)) /** @group getParam */ def getNumFolds: Int = getOrDefault(numFolds) @@ -93,6 +95,12 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP /** @group setParam */ def setNumFolds(value: Int): this.type = set(numFolds, value) + override def validate(paramMap: ParamMap): Unit = { + getEstimatorParamMaps.foreach { eMap => + getEstimator.validate(eMap ++ paramMap) + } + } + override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { val map = extractParamMap(paramMap) val schema = dataset.schema @@ -101,8 +109,8 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP val est = map(estimator) val eval = map(evaluator) val epm = map(estimatorParamMaps) - val numModels = epm.size - val metrics = new Array[Double](epm.size) + val numModels = epm.length + val metrics = new Array[Double](epm.length) val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() @@ -148,6 +156,10 @@ class CrossValidatorModel private[ml] ( val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams { + override def validate(paramMap: ParamMap): Unit = { + bestModel.validate(paramMap) + } + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { bestModel.transform(dataset, paramMap) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala similarity index 97% rename from mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala rename to mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala index a1d49095c24ac..8a56748ab0a02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.ml +package org.apache.spark.ml.util import java.util.UUID diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 057b628c6a586..bd2e9079ce1ae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -23,6 +23,7 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} import org.apache.spark.rdd.RDD @@ -46,7 +47,7 @@ class LogisticRegressionModel ( val numFeatures: Int, val numClasses: Int) extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable - with Saveable { + with Saveable with PMMLExportable { if (numClasses == 2) { require(weights.size == numFeatures, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 52fb62dcff1b4..33104cf06c6ea 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -22,6 +22,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable} import org.apache.spark.rdd.RDD @@ -36,7 +37,7 @@ class SVMModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable - with Saveable { + with Saveable with PMMLExportable { private var threshold: Option[Double] = Some(0.0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index e4e411a3c8b42..ba228b11fcec3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext @@ -34,7 +35,8 @@ import org.apache.spark.sql.Row /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ -class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable { +class KMeansModel ( + val clusterCenters: Array[Vector]) extends Saveable with Serializable with PMMLExportable { /** A Java-friendly constructor that takes an Iterable of Vectors. */ def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 166c00cff634d..188d1e542b5b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -52,7 +52,7 @@ sealed trait Vector extends Serializable { override def equals(other: Any): Boolean = { other match { - case v2: Vector => { + case v2: Vector => if (this.size != v2.size) return false (this, v2) match { case (s1: SparseVector, s2: SparseVector) => @@ -63,20 +63,28 @@ sealed trait Vector extends Serializable { Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values) case (_, _) => util.Arrays.equals(this.toArray, v2.toArray) } - } case _ => false } } + /** + * Returns a hash code value for the vector. The hash code is based on its size and its nonzeros + * in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]]. + */ override def hashCode(): Int = { - var result: Int = size + 31 - this.foreachActive { case (index, value) => - // ignore explict 0 for comparison between sparse and dense - if (value != 0) { - result = 31 * result + index - // refer to {@link java.util.Arrays.equals} for hash algorithm - val bits = java.lang.Double.doubleToLongBits(value) - result = 31 * result + (bits ^ (bits >>> 32)).toInt + // This is a reference implementation. It calls return in foreachActive, which is slow. + // Subclasses should override it with optimized implementation. + var result: Int = 31 + size + this.foreachActive { (index, value) => + if (index < 16) { + // ignore explicit 0 for comparison between sparse and dense + if (value != 0) { + result = 31 * result + index + val bits = java.lang.Double.doubleToLongBits(value) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } else { + return result } } result @@ -85,7 +93,7 @@ sealed trait Vector extends Serializable { /** * Converts the instance to a breeze vector. */ - private[mllib] def toBreeze: BV[Double] + private[spark] def toBreeze: BV[Double] /** * Gets the value of the ith element. @@ -108,6 +116,40 @@ sealed trait Vector extends Serializable { * with type `Double`. */ private[spark] def foreachActive(f: (Int, Double) => Unit) + + /** + * Number of active entries. An "active entry" is an element which is explicitly stored, + * regardless of its value. Note that inactive entries have value 0. + */ + def numActives: Int + + /** + * Number of nonzero elements. This scans all active values and count nonzeros. + */ + def numNonzeros: Int + + /** + * Converts this vector to a sparse vector with all explicit zeros removed. + */ + def toSparse: SparseVector + + /** + * Converts this vector to a dense vector. + */ + def toDense: DenseVector = new DenseVector(this.toArray) + + /** + * Returns a vector in either dense or sparse format, whichever uses less storage. + */ + def compressed: Vector = { + val nnz = numNonzeros + // A dense vector needs 8 * size + 8 bytes, while a sparse vector needs 12 * nnz + 20 bytes. + if (1.5 * (nnz + 1.0) < size) { + toSparse + } else { + toDense + } + } } /** @@ -284,7 +326,7 @@ object Vectors { /** * Creates a vector instance from a breeze vector. */ - private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = { + private[spark] def fromBreeze(breezeVector: BV[Double]): Vector = { breezeVector match { case v: BDV[Double] => if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) { @@ -317,7 +359,7 @@ object Vectors { case SparseVector(n, ids, vs) => vs case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } - val size = values.size + val size = values.length if (p == 1) { var sum = 0.0 @@ -371,8 +413,8 @@ object Vectors { val v1Indices = v1.indices val v2Values = v2.values val v2Indices = v2.indices - val nnzv1 = v1Indices.size - val nnzv2 = v2Indices.size + val nnzv1 = v1Indices.length + val nnzv2 = v2Indices.length var kv1 = 0 var kv2 = 0 @@ -401,7 +443,7 @@ object Vectors { case (DenseVector(vv1), DenseVector(vv2)) => var kv = 0 - val sz = vv1.size + val sz = vv1.length while (kv < sz) { val score = vv1(kv) - vv2(kv) squaredDistance += score * score @@ -422,7 +464,7 @@ object Vectors { var kv2 = 0 val indices = v1.indices var squaredDistance = 0.0 - val nnzv1 = indices.size + val nnzv1 = indices.length val nnzv2 = v2.size var iv1 = if (nnzv1 > 0) indices(kv1) else -1 @@ -451,8 +493,8 @@ object Vectors { v1Values: Array[Double], v2Indices: IndexedSeq[Int], v2Values: Array[Double]): Boolean = { - val v1Size = v1Values.size - val v2Size = v2Values.size + val v1Size = v1Values.length + val v2Size = v2Values.length var k1 = 0 var k2 = 0 var allEqual = true @@ -483,7 +525,7 @@ class DenseVector(val values: Array[Double]) extends Vector { override def toArray: Array[Double] = values - private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) + private[spark] override def toBreeze: BV[Double] = new BDV[Double](values) override def apply(i: Int): Double = values(i) @@ -493,7 +535,7 @@ class DenseVector(val values: Array[Double]) extends Vector { private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 - val localValuesSize = values.size + val localValuesSize = values.length val localValues = values while (i < localValuesSize) { @@ -501,6 +543,50 @@ class DenseVector(val values: Array[Double]) extends Vector { i += 1 } } + + override def hashCode(): Int = { + var result: Int = 31 + size + var i = 0 + val end = math.min(values.length, 16) + while (i < end) { + val v = values(i) + if (v != 0.0) { + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(values(i)) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + i += 1 + } + result + } + + override def numActives: Int = size + + override def numNonzeros: Int = { + // same as values.count(_ != 0.0) but faster + var nnz = 0 + values.foreach { v => + if (v != 0.0) { + nnz += 1 + } + } + nnz + } + + override def toSparse: SparseVector = { + val nnz = numNonzeros + val ii = new Array[Int](nnz) + val vv = new Array[Double](nnz) + var k = 0 + foreachActive { (i, v) => + if (v != 0) { + ii(k) = i + vv(k) = v + k += 1 + } + } + new SparseVector(size, ii, vv) + } } object DenseVector { @@ -522,8 +608,8 @@ class SparseVector( val values: Array[Double]) extends Vector { require(indices.length == values.length, "Sparse vectors require that the dimension of the" + - s" indices match the dimension of the values. You provided ${indices.size} indices and " + - s" ${values.size} values.") + s" indices match the dimension of the values. You provided ${indices.length} indices and " + + s" ${values.length} values.") override def toString: String = s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" @@ -543,11 +629,11 @@ class SparseVector( new SparseVector(size, indices.clone(), values.clone()) } - private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) + private[spark] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 - val localValuesSize = values.size + val localValuesSize = values.length val localIndices = indices val localValues = values @@ -556,6 +642,59 @@ class SparseVector( i += 1 } } + + override def hashCode(): Int = { + var result: Int = 31 + size + val end = values.length + var continue = true + var k = 0 + while ((k < end) & continue) { + val i = indices(k) + if (i < 16) { + val v = values(k) + if (v != 0.0) { + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(v) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } else { + continue = false + } + k += 1 + } + result + } + + override def numActives: Int = values.length + + override def numNonzeros: Int = { + var nnz = 0 + values.foreach { v => + if (v != 0.0) { + nnz += 1 + } + } + nnz + } + + override def toSparse: SparseVector = { + val nnz = numNonzeros + if (nnz == numActives) { + this + } else { + val ii = new Array[Int](nnz) + val vv = new Array[Double](nnz) + var k = 0 + foreachActive { (i, v) => + if (v != 0.0) { + ii(k) = i + vv(k) = v + k += 1 + } + } + new SparseVector(size, ii, vv) + } + } } object SparseVector { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 8bfa0d2b64995..240baeb5a158b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -37,7 +37,11 @@ abstract class Gradient extends Serializable { * * @return (gradient: Vector, loss: Double) */ - def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) + def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val gradient = Vectors.zeros(weights.size) + val loss = compute(data, label, weights, gradient) + (gradient, loss) + } /** * Compute the gradient and loss given the features of a single data point, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index ef6eccd90711a..efedc112d380e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.optimization +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseVector => BDV} @@ -164,7 +165,7 @@ object LBFGS extends Logging { regParam: Double, initialWeights: Vector): (Vector, Array[Double]) = { - val lossHistory = new ArrayBuffer[Double](maxNumIterations) + val lossHistory = mutable.ArrayBuilder.make[Double] val numExamples = data.count() @@ -181,17 +182,19 @@ object LBFGS extends Logging { * and regVal is the regularization value computed in the previous iteration as well. */ var state = states.next() - while(states.hasNext) { - lossHistory.append(state.value) + while (states.hasNext) { + lossHistory += state.value state = states.next() } - lossHistory.append(state.value) + lossHistory += state.value val weights = Vectors.fromBreeze(state.x) + val lossHistoryArray = lossHistory.result() + logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format( - lossHistory.takeRight(10).mkString(", "))) + lossHistoryArray.takeRight(10).mkString(", "))) - (weights, lossHistory.toArray) + (weights, lossHistoryArray) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala new file mode 100644 index 0000000000000..354e90f3eeaa6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml + +import java.io.{File, OutputStream, StringWriter} +import javax.xml.transform.stream.StreamResult + +import org.jpmml.model.JAXBUtil + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory + +/** + * Export model to the PMML format + * Predictive Model Markup Language (PMML) is an XML-based file format + * developed by the Data Mining Group (www.dmg.org). + */ +trait PMMLExportable { + + /** + * Export the model to the stream result in PMML format + */ + private def toPMML(streamResult: StreamResult): Unit = { + val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this) + JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult) + } + + /** + * Export the model to a local file in PMML format + */ + def toPMML(localPath: String): Unit = { + toPMML(new StreamResult(new File(localPath))) + } + + /** + * Export the model to a directory on a distributed file system in PMML format + */ + def toPMML(sc: SparkContext, path: String): Unit = { + val pmml = toPMML() + sc.parallelize(Array(pmml), 1).saveAsTextFile(path) + } + + /** + * Export the model to the OutputStream in PMML format + */ + def toPMML(outputStream: OutputStream): Unit = { + toPMML(new StreamResult(outputStream)) + } + + /** + * Export the model to a String in PMML format + */ + def toPMML(): String = { + val writer = new StringWriter + toPMML(new StreamResult(writer)) + writer.toString + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala new file mode 100644 index 0000000000000..34b447584e521 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import scala.{Array => SArray} + +import org.dmg.pmml._ + +import org.apache.spark.mllib.regression.GeneralizedLinearModel + +/** + * PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel + */ +private[mllib] class BinaryClassificationPMMLModelExport( + model : GeneralizedLinearModel, + description : String, + normalizationMethod : RegressionNormalizationMethodType, + threshold: Double) + extends PMMLModelExport { + + populateBinaryClassificationPMML() + + /** + * Export the input LogisticRegressionModel or SVMModel to PMML format. + */ + private def populateBinaryClassificationPMML(): Unit = { + pmml.getHeader.setDescription(description) + + if (model.weights.size > 0) { + val fields = new SArray[FieldName](model.weights.size) + val dataDictionary = new DataDictionary + val miningSchema = new MiningSchema + val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1") + var interceptNO = threshold + if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) { + if (threshold <= 0) { + interceptNO = Double.MinValue + } else if (threshold >= 1) { + interceptNO = Double.MaxValue + } else { + interceptNO = -math.log(1 / threshold - 1) + } + } + val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0") + val regressionModel = new RegressionModel() + .withFunctionName(MiningFunctionType.CLASSIFICATION) + .withMiningSchema(miningSchema) + .withModelName(description) + .withNormalizationMethod(normalizationMethod) + .withRegressionTables(regressionTableYES, regressionTableNO) + + for (i <- 0 until model.weights.size) { + fields(i) = FieldName.create("field_" + i) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) + regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + } + + // add target field + val targetField = FieldName.create("target") + dataDictionary + .withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) + miningSchema + .withMiningFields(new MiningField(targetField) + .withUsageType(FieldUsageType.TARGET)) + + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + + pmml.setDataDictionary(dataDictionary) + pmml.withModels(regressionModel) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala new file mode 100644 index 0000000000000..1874786af0002 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import scala.{Array => SArray} + +import org.dmg.pmml._ + +import org.apache.spark.mllib.regression.GeneralizedLinearModel + +/** + * PMML Model Export for GeneralizedLinearModel abstract class + */ +private[mllib] class GeneralizedLinearPMMLModelExport( + model: GeneralizedLinearModel, + description: String) + extends PMMLModelExport { + + populateGeneralizedLinearPMML(model) + + /** + * Export the input GeneralizedLinearModel model to PMML format. + */ + private def populateGeneralizedLinearPMML(model: GeneralizedLinearModel): Unit = { + pmml.getHeader.setDescription(description) + + if (model.weights.size > 0) { + val fields = new SArray[FieldName](model.weights.size) + val dataDictionary = new DataDictionary + val miningSchema = new MiningSchema + val regressionTable = new RegressionTable(model.intercept) + val regressionModel = new RegressionModel() + .withFunctionName(MiningFunctionType.REGRESSION) + .withMiningSchema(miningSchema) + .withModelName(description) + .withRegressionTables(regressionTable) + + for (i <- 0 until model.weights.size) { + fields(i) = FieldName.create("field_" + i) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) + regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + } + + // for completeness add target field + val targetField = FieldName.create("target") + dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(targetField) + .withUsageType(FieldUsageType.TARGET)) + + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + + pmml.setDataDictionary(dataDictionary) + pmml.withModels(regressionModel) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala new file mode 100644 index 0000000000000..069e7afc9fca0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import scala.{Array => SArray} + +import org.dmg.pmml._ + +import org.apache.spark.mllib.clustering.KMeansModel + +/** + * PMML Model Export for KMeansModel class + */ +private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{ + + populateKMeansPMML(model) + + /** + * Export the input KMeansModel model to PMML format. + */ + private def populateKMeansPMML(model : KMeansModel): Unit = { + pmml.getHeader.setDescription("k-means clustering") + + if (model.clusterCenters.length > 0) { + val clusterCenter = model.clusterCenters(0) + val fields = new SArray[FieldName](clusterCenter.size) + val dataDictionary = new DataDictionary + val miningSchema = new MiningSchema + val comparisonMeasure = new ComparisonMeasure() + .withKind(ComparisonMeasure.Kind.DISTANCE) + .withMeasure(new SquaredEuclidean()) + val clusteringModel = new ClusteringModel() + .withModelName("k-means") + .withMiningSchema(miningSchema) + .withComparisonMeasure(comparisonMeasure) + .withFunctionName(MiningFunctionType.CLUSTERING) + .withModelClass(ClusteringModel.ModelClass.CENTER_BASED) + .withNumberOfClusters(model.clusterCenters.length) + + for (i <- 0 until clusterCenter.size) { + fields(i) = FieldName.create("field_" + i) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) + clusteringModel.withClusteringFields( + new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF)) + } + + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + + for (i <- 0 until model.clusterCenters.length) { + val cluster = new Cluster() + .withName("cluster_" + i) + .withArray(new org.dmg.pmml.Array() + .withType(Array.Type.REAL) + .withN(clusterCenter.size) + .withValue(model.clusterCenters(i).toArray.mkString(" "))) + // we don't have the size of the single cluster but only the centroids (withValue) + // .withSize(value) + clusteringModel.withClusters(cluster) + } + + pmml.setDataDictionary(dataDictionary) + pmml.withModels(clusteringModel) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala new file mode 100644 index 0000000000000..ebdeae50bb32f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import java.text.SimpleDateFormat +import java.util.Date + +import scala.beans.BeanProperty + +import org.dmg.pmml.{Application, Header, PMML, Timestamp} + +private[mllib] trait PMMLModelExport { + + /** + * Holder of the exported model in PMML format + */ + @BeanProperty + val pmml: PMML = new PMML + + setHeader(pmml) + + private def setHeader(pmml: PMML): Unit = { + val version = getClass.getPackage.getImplementationVersion + val app = new Application().withName("Apache Spark MLlib").withVersion(version) + val timestamp = new Timestamp() + .withContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) + val header = new Header() + .withApplication(app) + .withTimestamp(timestamp) + pmml.setHeader(header) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala new file mode 100644 index 0000000000000..c16e83d6a067d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import org.dmg.pmml.RegressionNormalizationMethodType + +import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.classification.SVMModel +import org.apache.spark.mllib.clustering.KMeansModel +import org.apache.spark.mllib.regression.LassoModel +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.RidgeRegressionModel + +private[mllib] object PMMLModelExportFactory { + + /** + * Factory object to help creating the necessary PMMLModelExport implementation + * taking as input the machine learning model (for example KMeansModel). + */ + def createPMMLModelExport(model: Any): PMMLModelExport = { + model match { + case kmeans: KMeansModel => + new KMeansPMMLModelExport(kmeans) + case linear: LinearRegressionModel => + new GeneralizedLinearPMMLModelExport(linear, "linear regression") + case ridge: RidgeRegressionModel => + new GeneralizedLinearPMMLModelExport(ridge, "ridge regression") + case lasso: LassoModel => + new GeneralizedLinearPMMLModelExport(lasso, "lasso regression") + case svm: SVMModel => + new BinaryClassificationPMMLModelExport( + svm, "linear SVM", RegressionNormalizationMethodType.NONE, + svm.getThreshold.getOrElse(0.0)) + case logistic: LogisticRegressionModel => + if (logistic.numClasses == 2) { + new BinaryClassificationPMMLModelExport( + logistic, "logistic regression", RegressionNormalizationMethodType.LOGIT, + logistic.getThreshold.getOrElse(0.5)) + } else { + throw new IllegalArgumentException( + "PMML Export not supported for Multinomial Logistic Regression") + } + case _ => + throw new IllegalArgumentException( + "PMML Export not supported for model: " + model.getClass.getName) + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 9fd60ff7a0c79..26be30ff9d6fd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -225,7 +225,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] throw new SparkException("Input validation failed.") } - /* + /** * Scaling columns to unit variance as a heuristic to reduce the condition number: * * During the optimization process, the convergence (rate) depends on the condition number of diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index e8b03816573cf..4f482384f0f38 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel import org.apache.spark.mllib.util.{Saveable, Loader} import org.apache.spark.rdd.RDD @@ -34,7 +35,7 @@ class LassoModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable with Saveable { + with RegressionModel with Serializable with Saveable with PMMLExportable { override protected def predictPoint( dataMatrix: Vector, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 6fa7ad52a5b33..9453c4f66c216 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel import org.apache.spark.mllib.util.{Saveable, Loader} import org.apache.spark.rdd.RDD @@ -34,7 +35,7 @@ class LinearRegressionModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable - with Saveable { + with Saveable with PMMLExportable { override protected def predictPoint( dataMatrix: Vector, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 309f9af466457..e0c03d8180c7a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD @@ -35,7 +36,7 @@ class RidgeRegressionModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable with Saveable { + with RegressionModel with Serializable with Saveable with PMMLExportable { override protected def predictPoint( dataMatrix: Vector, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala index 23b291eee070b..8a821d1b23bab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala @@ -101,7 +101,7 @@ private[stat] object PearsonCorrelation extends Correlation with Logging { Matrices.fromBreeze(cov) } - private def closeToZero(value: Double, threshhold: Double = 1e-12): Boolean = { - math.abs(value) <= threshhold + private def closeToZero(value: Double, threshold: Double = 1e-12): Boolean = { + math.abs(value) <= threshold } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 0e31c7ed58df8..1f779584dcffd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -177,8 +177,11 @@ object GradientBoostedTrees extends Logging { treeStrategy.assertValid() // Cache input - if (input.getStorageLevel == StorageLevel.NONE) { + val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) { input.persist(StorageLevel.MEMORY_AND_DISK) + true + } else { + false } timer.stop("init") @@ -265,6 +268,9 @@ object GradientBoostedTrees extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + + if (persistedInput) input.unpersist() + if (validate) { new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index c9d33787b0bb5..b1a4517344970 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -56,6 +56,10 @@ object LinearDataGenerator { } /** + * For compatibility, the generated data without specifying the mean and variance + * will have zero mean and variance of (1.0/3.0) since the original output range is + * [-1, 1] with uniform distribution, and the variance of uniform distribution + * is (b - a)^2^ / 12 which will be (1.0/3.0) * * @param intercept Data intercept * @param weights Weights to be applied. @@ -70,10 +74,45 @@ object LinearDataGenerator { nPoints: Int, seed: Int, eps: Double = 0.1): Seq[LabeledPoint] = { + generateLinearInput(intercept, weights, + Array.fill[Double](weights.length)(0.0), + Array.fill[Double](weights.length)(1.0 / 3.0), + nPoints, seed, eps)} + + /** + * + * @param intercept Data intercept + * @param weights Weights to be applied. + * @param xMean the mean of the generated features. Lots of time, if the features are not properly + * standardized, the algorithm with poor implementation will have difficulty + * to converge. + * @param xVariance the variance of the generated features. + * @param nPoints Number of points in sample. + * @param seed Random seed + * @param eps Epsilon scaling factor. + * @return Seq of input. + */ + def generateLinearInput( + intercept: Double, + weights: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + eps: Double): Seq[LabeledPoint] = { val rnd = new Random(seed) val x = Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(2 * rnd.nextDouble - 1.0)) + Array.fill[Double](weights.length)(rnd.nextDouble())) + + x.foreach { v => + var i = 0 + while (i < v.length) { + v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + i += 1 + } + } + val y = x.map { xi => blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian() } diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java new file mode 100644 index 0000000000000..e7df10dfa63ac --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; + +/** + * Test Param and related classes in Java + */ +public class JavaParamsSuite { + + private transient JavaSparkContext jsc; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaParamsSuite"); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void testParams() { + JavaTestParams testParams = new JavaTestParams(); + Assert.assertEquals(testParams.getMyIntParam(), 1); + testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a"); + Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0); + Assert.assertEquals(testParams.getMyStringParam(), "a"); + } + + @Test + public void testParamValidate() { + ParamValidators.gt(1.0); + ParamValidators.gtEq(1.0); + ParamValidators.lt(1.0); + ParamValidators.ltEq(1.0); + ParamValidators.inRange(0, 1, true, false); + ParamValidators.inRange(0, 1); + ParamValidators.inArray(Lists.newArrayList(0, 1, 3)); + ParamValidators.inArray(Lists.newArrayList("a", "b")); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java new file mode 100644 index 0000000000000..8abe575610d19 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param; + +import java.util.List; + +import com.google.common.collect.Lists; + +/** + * A subclass of Params for testing. + */ +public class JavaTestParams extends JavaParams { + + public IntParam myIntParam; + + public int getMyIntParam() { return (Integer)getOrDefault(myIntParam); } + + public JavaTestParams setMyIntParam(int value) { + set(myIntParam, value); return this; + } + + public DoubleParam myDoubleParam; + + public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam); } + + public JavaTestParams setMyDoubleParam(double value) { + set(myDoubleParam, value); return this; + } + + public Param myStringParam; + + public String getMyStringParam() { return (String)getOrDefault(myStringParam); } + + public JavaTestParams setMyStringParam(String value) { + set(myStringParam, value); return this; + } + + public JavaTestParams() { + myIntParam = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0)); + myDoubleParam = new DoubleParam(this, "myDoubleParam", "this is a double param", + ParamValidators.inRange(0.0, 1.0)); + List validStrings = Lists.newArrayList("a", "b"); + myStringParam = new Param(this, "myStringParam", "this is a string param", + ParamValidators.inArray(validStrings)); + setDefault(myIntParam, 1); + setDefault(myDoubleParam, 0.5); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 00b5d094d82f1..b6939e5870410 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -49,4 +49,23 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) } + + test("StringIndexer with a numeric input column") { + val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("100", "300", "200")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // 100 -> 0, 200 -> 2, 300 -> 1 + val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) + assert(output === expected) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 1b261b2643854..38dc83b1241cf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -23,7 +23,6 @@ import org.scalatest.FunSuite import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.util.TestingUtils import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD @@ -111,8 +110,8 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { val model = vectorIndexer.fit(densePoints1) // vectors of length 3 model.transform(densePoints1) // should work model.transform(sparsePoints1) // should work - intercept[IllegalArgumentException] { - model.transform(densePoints2) + intercept[SparkException] { + model.transform(densePoints2).collect() println("Did not throw error when fit, transform were called on vectors of different lengths") } intercept[SparkException] { @@ -245,8 +244,6 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { // TODO: Once input features marked as categorical are handled correctly, check that here. } } - // Check that non-ML metadata are preserved. - TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala new file mode 100644 index 0000000000000..03ba86670d453 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext} + +class Word2VecSuite extends FunSuite with MLlibTestSparkContext { + + test("Word2Vec") { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val sentence = "a b " * 100 + "a c " * 10 + val numOfWords = sentence.split(" ").size + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + + val codes = Map( + "a" -> Array(-0.2811822295188904,-0.6356269121170044,-0.3020961284637451), + "b" -> Array(1.0309048891067505,-1.29472815990448,0.22276712954044342), + "c" -> Array(-0.08456747233867645,0.5137411952018738,0.11731560528278351) + ) + + val expected = doc.map { sentence => + Vectors.dense(sentence.map(codes.apply).reduce((word1, word2) => + word1.zip(word2).map { case (v1, v2) => v1 + v2 } + ).map(_ / numOfWords)) + } + + val docDF = doc.zip(expected).toDF("text", "expected") + + val model = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .fit(docDF) + + model.transform(docDF).select("result", "expected").collect().foreach { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") + } + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 88ea679eeaad5..f8852606abbf2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -26,14 +26,22 @@ class ParamsSuite extends FunSuite { import solver.{maxIter, inputCol} assert(maxIter.name === "maxIter") - assert(maxIter.doc === "max number of iterations") + assert(maxIter.doc === "max number of iterations (>= 0)") assert(maxIter.parent.eq(solver)) - assert(maxIter.toString === "maxIter: max number of iterations (default: 10)") + assert(maxIter.toString === "maxIter: max number of iterations (>= 0) (default: 10)") + assert(!maxIter.isValid(-1)) + assert(maxIter.isValid(0)) + assert(maxIter.isValid(1)) solver.setMaxIter(5) - assert(maxIter.toString === "maxIter: max number of iterations (default: 10, current: 5)") + assert(maxIter.toString === + "maxIter: max number of iterations (>= 0) (default: 10, current: 5)") assert(inputCol.toString === "inputCol: input column name (undefined)") + + intercept[IllegalArgumentException] { + solver.setMaxIter(-1) + } } test("param pair") { @@ -47,6 +55,9 @@ class ParamsSuite extends FunSuite { assert(pair.param.eq(maxIter)) assert(pair.value === 5) } + intercept[IllegalArgumentException] { + val pair = maxIter -> -1 + } } test("param map") { @@ -59,6 +70,9 @@ class ParamsSuite extends FunSuite { map0.put(maxIter, 10) assert(map0.contains(maxIter)) assert(map0(maxIter) === 10) + intercept[IllegalArgumentException] { + map0.put(maxIter, -1) + } assert(!map0.contains(inputCol)) intercept[NoSuchElementException] { @@ -122,14 +136,57 @@ class ParamsSuite extends FunSuite { assert(solver.getInputCol === "input") solver.validate() intercept[IllegalArgumentException] { - solver.validate(ParamMap(maxIter -> -10)) + ParamMap(maxIter -> -10) } - solver.setMaxIter(-10) intercept[IllegalArgumentException] { - solver.validate() + solver.setMaxIter(-10) } solver.clearMaxIter() assert(!solver.isSet(maxIter)) } + + test("ParamValidate") { + val alwaysTrue = ParamValidators.alwaysTrue[Int] + assert(alwaysTrue(1)) + + val gt1Int = ParamValidators.gt[Int](1) + assert(!gt1Int(1) && gt1Int(2)) + val gt1Double = ParamValidators.gt[Double](1) + assert(!gt1Double(1.0) && gt1Double(1.1)) + + val gtEq1Int = ParamValidators.gtEq[Int](1) + assert(!gtEq1Int(0) && gtEq1Int(1)) + val gtEq1Double = ParamValidators.gtEq[Double](1) + assert(!gtEq1Double(0.9) && gtEq1Double(1.0)) + + val lt1Int = ParamValidators.lt[Int](1) + assert(lt1Int(0) && !lt1Int(1)) + val lt1Double = ParamValidators.lt[Double](1) + assert(lt1Double(0.9) && !lt1Double(1.0)) + + val ltEq1Int = ParamValidators.ltEq[Int](1) + assert(ltEq1Int(1) && !ltEq1Int(2)) + val ltEq1Double = ParamValidators.ltEq[Double](1) + assert(ltEq1Double(1.0) && !ltEq1Double(1.1)) + + val inRange02IntInclusive = ParamValidators.inRange[Int](0, 2) + assert(inRange02IntInclusive(0) && inRange02IntInclusive(1) && inRange02IntInclusive(2) && + !inRange02IntInclusive(-1) && !inRange02IntInclusive(3)) + val inRange02IntExclusive = + ParamValidators.inRange[Int](0, 2, lowerInclusive = false, upperInclusive = false) + assert(!inRange02IntExclusive(0) && inRange02IntExclusive(1) && !inRange02IntExclusive(2)) + + val inRange02DoubleInclusive = ParamValidators.inRange[Double](0, 2) + assert(inRange02DoubleInclusive(0) && inRange02DoubleInclusive(1) && + inRange02DoubleInclusive(2) && + !inRange02DoubleInclusive(-0.1) && !inRange02DoubleInclusive(2.1)) + val inRange02DoubleExclusive = + ParamValidators.inRange[Double](0, 2, lowerInclusive = false, upperInclusive = false) + assert(!inRange02DoubleExclusive(0) && inRange02DoubleExclusive(1) && + !inRange02DoubleExclusive(2)) + + val inArray = ParamValidators.inArray[Int](Array(1, 2)) + assert(inArray(1) && inArray(2) && !inArray(0)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 641b64b42a5e7..6f9c9cb5166ae 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -29,7 +29,7 @@ class TestParams extends Params with HasMaxIter with HasInputCol { override def validate(paramMap: ParamMap): Unit = { val m = extractParamMap(paramMap) - require(m(maxIter) >= 0) + // Note: maxIter is validated when it is set. require(m.contains(inputCol)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index bbb44c3e2dfc2..80323ef5201a6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -19,47 +19,149 @@ package org.apache.spark.ml.regression import org.scalatest.FunSuite -import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.mllib.linalg.DenseVector +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext, DataFrame} class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { @transient var sqlContext: SQLContext = _ @transient var dataset: DataFrame = _ + /** + * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML + * is the same as the one trained by R's glmnet package. The following instruction + * describes how to reproduce the data in R. + * + * import org.apache.spark.mllib.util.LinearDataGenerator + * val data = + * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2) + * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path") + */ override def beforeAll(): Unit = { super.beforeAll() sqlContext = new SQLContext(sc) dataset = sqlContext.createDataFrame( - sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2)) + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) } - test("linear regression: default params") { - val lr = new LinearRegression - assert(lr.getLabelCol == "label") - val model = lr.fit(dataset) - model.transform(dataset) - .select("label", "prediction") - .collect() - // Check defaults - assert(model.getFeaturesCol == "features") - assert(model.getPredictionCol == "prediction") + test("linear regression with intercept without regularization") { + val trainer = new LinearRegression + val model = trainer.fit(dataset) + + /** + * Using the following R code to load the data and train the model using glmnet package. + * + * library("glmnet") + * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + * label <- as.numeric(data$V1) + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.300528 + * as.numeric.data.V2. 4.701024 + * as.numeric.data.V3. 7.198257 + */ + val interceptR = 6.298698 + val weightsR = Array(4.700706, 7.199082) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression with intercept with L1 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.311546 + * as.numeric.data.V2. 2.123522 + * as.numeric.data.V3. 4.605651 + */ + val interceptR = 6.243000 + val weightsR = Array(4.024821, 6.679841) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } } - test("linear regression with setters") { - // Set params, train, and check as many as we can. - val lr = new LinearRegression() - .setMaxIter(10) - .setRegParam(1.0) - val model = lr.fit(dataset) - assert(model.fittingParamMap.get(lr.maxIter).get === 10) - assert(model.fittingParamMap.get(lr.regParam).get === 1.0) - - // Call fit() with new params, and check as many as we can. - val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred") - assert(model2.fittingParamMap.get(lr.maxIter).get === 5) - assert(model2.fittingParamMap.get(lr.regParam).get === 0.1) - assert(model2.getPredictionCol == "thePred") + test("linear regression with intercept with L2 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.328062 + * as.numeric.data.V2. 3.222034 + * as.numeric.data.V3. 4.926260 + */ + val interceptR = 5.269376 + val weightsR = Array(3.736216, 5.712356) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression with intercept with ElasticNet regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.324108 + * as.numeric.data.V2. 3.168435 + * as.numeric.data.V3. 5.200403 + */ + val interceptR = 5.696056 + val weightsR = Array(3.670489, 6.001122) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala deleted file mode 100644 index c44cb61b34171..0000000000000 --- a/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.util - -import org.apache.spark.ml.Transformer -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.MetadataBuilder -import org.scalatest.FunSuite - -private[ml] object TestingUtils extends FunSuite { - - /** - * Test whether unrelated metadata are preserved for this transformer. - * This attaches extra metadata to a column, transforms the column, and check to ensure the - * extra metadata have not changed. - * @param data Input dataset - * @param transformer Transformer to test - * @param inputCol Unique input column for Transformer. This must be the ONLY input column. - * @param outputCol Output column to test for metadata presence. - */ - def testPreserveMetadata( - data: DataFrame, - transformer: Transformer, - inputCol: String, - outputCol: String): Unit = { - // Create some fake metadata - val origMetadata = data.schema(inputCol).metadata - val metaKey = "__testPreserveMetadata__fake_key" - val metaValue = 12345 - assert(!origMetadata.contains(metaKey), - s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey") - val newMetadata = - new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build() - // Add metadata to the inputCol - val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata)) - // Transform, and ensure extra metadata was not affected - val transformed = transformer.transform(withMetadata) - val transMetadata = transformed.schema(outputCol).metadata - assert(transMetadata.contains(metaKey), - "Unit test with testPreserveMetadata failed; extra metadata key was not present.") - assert(transMetadata.getLong(metaKey) === metaValue, - "Unit test with testPreserveMetadata failed; extra metadata value was wrong." + - s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}") - } -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 2839c4c289b2d..24755e9ff46fc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -270,4 +270,48 @@ class VectorsSuite extends FunSuite { assert(Vectors.norm(sv, 3.7) ~== math.pow(sv.toArray.foldLeft(0.0)((a, v) => a + math.pow(math.abs(v), 3.7)), 1.0 / 3.7) relTol 1E-8) } + + test("Vector numActive and numNonzeros") { + val dv = Vectors.dense(0.0, 2.0, 3.0, 0.0) + assert(dv.numActives === 4) + assert(dv.numNonzeros === 2) + + val sv = Vectors.sparse(4, Array(0, 1, 2), Array(0.0, 2.0, 3.0)) + assert(sv.numActives === 3) + assert(sv.numNonzeros === 2) + } + + test("Vector toSparse and toDense") { + val dv0 = Vectors.dense(0.0, 2.0, 3.0, 0.0) + assert(dv0.toDense === dv0) + val dv0s = dv0.toSparse + assert(dv0s.numActives === 2) + assert(dv0s === dv0) + + val sv0 = Vectors.sparse(4, Array(0, 1, 2), Array(0.0, 2.0, 3.0)) + assert(sv0.toDense === sv0) + val sv0s = sv0.toSparse + assert(sv0s.numActives === 2) + assert(sv0s === sv0) + } + + test("Vector.compressed") { + val dv0 = Vectors.dense(1.0, 2.0, 3.0, 0.0) + val dv0c = dv0.compressed.asInstanceOf[DenseVector] + assert(dv0c === dv0) + + val dv1 = Vectors.dense(0.0, 2.0, 0.0, 0.0) + val dv1c = dv1.compressed.asInstanceOf[SparseVector] + assert(dv1 === dv1c) + assert(dv1c.numActives === 1) + + val sv0 = Vectors.sparse(4, Array(1, 2), Array(2.0, 0.0)) + val sv0c = sv0.compressed.asInstanceOf[SparseVector] + assert(sv0 === sv0c) + assert(sv0c.numActives === 1) + + val sv1 = Vectors.sparse(4, Array(0, 1, 2), Array(1.0, 2.0, 3.0)) + val sv1c = sv1.compressed.asInstanceOf[DenseVector] + assert(sv1 === sv1c) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala new file mode 100644 index 0000000000000..0b646cf1ce6c4 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import org.dmg.pmml.RegressionModel +import org.dmg.pmml.RegressionNormalizationMethodType +import org.scalatest.FunSuite + +import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.classification.SVMModel +import org.apache.spark.mllib.util.LinearDataGenerator + +class BinaryClassificationPMMLModelExportSuite extends FunSuite { + + test("logistic regression PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val logisticRegressionModel = + new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) + + val logisticModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) + + // assert that the PMML format is as expected + assert(logisticModelExport.isInstanceOf[PMMLModelExport]) + val pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml + assert(pmml.getHeader.getDescription === "logistic regression") + // check that the number of fields match the weights size + assert(pmml.getDataDictionary.getNumberOfFields === logisticRegressionModel.weights.size + 1) + // This verify that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table (for target category 1) + // with the same number of predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "1") + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === logisticRegressionModel.weights.size) + // verify if there is a second table with target category 0 and no predictors + assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0") + assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0) + // ensure logistic regression has normalization method set to LOGIT + assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT) + } + + test("linear SVM PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) + + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) + + // assert that the PMML format is as expected + assert(svmModelExport.isInstanceOf[PMMLModelExport]) + val pmml = svmModelExport.getPmml + assert(pmml.getHeader.getDescription + === "linear SVM") + // check that the number of fields match the weights size + assert(pmml.getDataDictionary.getNumberOfFields === svmModel.weights.size + 1) + // This verify that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table (for target category 1) + // with the same number of predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "1") + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === svmModel.weights.size) + // verify if there is a second table with target category 0 and no predictors + assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0") + assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0) + // ensure linear SVM has normalization method set to NONE + assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala new file mode 100644 index 0000000000000..f9afbd888dfc5 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import org.dmg.pmml.RegressionModel +import org.scalatest.FunSuite + +import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} +import org.apache.spark.mllib.util.LinearDataGenerator + +class GeneralizedLinearPMMLModelExportSuite extends FunSuite { + + test("linear regression PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val linearRegressionModel = + new LinearRegressionModel(linearInput(0).features, linearInput(0).label) + val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel) + // assert that the PMML format is as expected + assert(linearModelExport.isInstanceOf[PMMLModelExport]) + val pmml = linearModelExport.getPmml + assert(pmml.getHeader.getDescription === "linear regression") + // check that the number of fields match the weights size + assert(pmml.getDataDictionary.getNumberOfFields === linearRegressionModel.weights.size + 1) + // This verifies that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table with the same number of + // predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === linearRegressionModel.weights.size) + } + + test("ridge regression PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val ridgeRegressionModel = + new RidgeRegressionModel(linearInput(0).features, linearInput(0).label) + val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel) + // assert that the PMML format is as expected + assert(ridgeModelExport.isInstanceOf[PMMLModelExport]) + val pmml = ridgeModelExport.getPmml + assert(pmml.getHeader.getDescription === "ridge regression") + // check that the number of fields match the weights size + assert(pmml.getDataDictionary.getNumberOfFields === ridgeRegressionModel.weights.size + 1) + // This verify that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table with the same number of + // predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === ridgeRegressionModel.weights.size) + } + + test("lasso PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label) + val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel) + // assert that the PMML format is as expected + assert(lassoModelExport.isInstanceOf[PMMLModelExport]) + val pmml = lassoModelExport.getPmml + assert(pmml.getHeader.getDescription === "lasso regression") + // check that the number of fields match the weights size + assert(pmml.getDataDictionary.getNumberOfFields === lassoModel.weights.size + 1) + // This verify that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table with the same number of + // predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === lassoModel.weights.size) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala new file mode 100644 index 0000000000000..b985d0446d7b0 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import org.dmg.pmml.ClusteringModel +import org.scalatest.FunSuite + +import org.apache.spark.mllib.clustering.KMeansModel +import org.apache.spark.mllib.linalg.Vectors + +class KMeansPMMLModelExportSuite extends FunSuite { + + test("KMeansPMMLModelExport generate PMML format") { + val clusterCenters = Array( + Vectors.dense(1.0, 2.0, 6.0), + Vectors.dense(1.0, 3.0, 0.0), + Vectors.dense(1.0, 4.0, 6.0)) + val kmeansModel = new KMeansModel(clusterCenters) + + val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel) + + // assert that the PMML format is as expected + assert(modelExport.isInstanceOf[PMMLModelExport]) + val pmml = modelExport.asInstanceOf[PMMLModelExport].getPmml + assert(pmml.getHeader.getDescription === "k-means clustering") + // check that the number of fields match the single vector size + assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size) + // This verify that there is a model attached to the pmml object and the model is a clustering + // one. It also verifies that the pmml model has the same number of clusters of the spark model. + val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] + assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala new file mode 100644 index 0000000000000..f28a4ac8ad01f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel} +import org.apache.spark.mllib.clustering.KMeansModel +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} +import org.apache.spark.mllib.util.LinearDataGenerator + +class PMMLModelExportFactorySuite extends FunSuite { + + test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") { + val clusterCenters = Array( + Vectors.dense(1.0, 2.0, 6.0), + Vectors.dense(1.0, 3.0, 0.0), + Vectors.dense(1.0, 4.0, 6.0)) + val kmeansModel = new KMeansModel(clusterCenters) + + val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel) + + assert(modelExport.isInstanceOf[KMeansPMMLModelExport]) + } + + test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a " + + "LinearRegressionModel, RidgeRegressionModel or LassoModel") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + + val linearRegressionModel = + new LinearRegressionModel(linearInput(0).features, linearInput(0).label) + val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel) + assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + + val ridgeRegressionModel = + new RidgeRegressionModel(linearInput(0).features, linearInput(0).label) + val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel) + assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + + val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label) + val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel) + assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + } + + test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport " + + "when passing a LogisticRegressionModel or SVMModel") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + + val logisticRegressionModel = + new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) + val logisticRegressionModelExport = + PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) + assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) + + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) + assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) + } + + test("PMMLModelExportFactory throw IllegalArgumentException " + + "when passing a Multinomial Logistic Regression") { + /** 3 classes, 2 features */ + val multiclassLogisticRegressionModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, + numFeatures = 2, numClasses = 3) + + intercept[IllegalArgumentException] { + PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel) + } + } + + test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") { + val invalidModel = new Object + + intercept[IllegalArgumentException] { + PMMLModelExportFactory.createPMMLModelExport(invalidModel) + } + } +} diff --git a/network/common/pom.xml b/network/common/pom.xml index 22c738bde6d42..0c3147761cfc5 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -95,7 +95,6 @@ org.apache.maven.plugins maven-jar-plugin - 2.2 test-jar-on-test-compile diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java new file mode 100644 index 0000000000000..36d655017fb0d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.util; + +public enum ByteUnit { + BYTE (1), + KiB (1024L), + MiB ((long) Math.pow(1024L, 2L)), + GiB ((long) Math.pow(1024L, 3L)), + TiB ((long) Math.pow(1024L, 4L)), + PiB ((long) Math.pow(1024L, 5L)); + + private ByteUnit(long multiplier) { + this.multiplier = multiplier; + } + + // Interpret the provided number (d) with suffix (u) as this unit type. + // E.g. KiB.interpret(1, MiB) interprets 1MiB as its KiB representation = 1024k + public long convertFrom(long d, ByteUnit u) { + return u.convertTo(d, this); + } + + // Convert the provided number (d) interpreted as this unit type to unit type (u). + public long convertTo(long d, ByteUnit u) { + if (multiplier > u.multiplier) { + long ratio = multiplier / u.multiplier; + if (Long.MAX_VALUE / ratio < d) { + throw new IllegalArgumentException("Conversion of " + d + " exceeds Long.MAX_VALUE in " + + name() + ". Try a larger unit (e.g. MiB instead of KiB)"); + } + return d * ratio; + } else { + // Perform operations in this order to avoid potential overflow + // when computing d * multiplier + return d / (u.multiplier / multiplier); + } + } + + public double toBytes(long d) { + if (d < 0) { + throw new IllegalArgumentException("Negative size value. Size must be positive: " + d); + } + return d * multiplier; + } + + public long toKiB(long d) { return convertTo(d, KiB); } + public long toMiB(long d) { return convertTo(d, MiB); } + public long toGiB(long d) { return convertTo(d, GiB); } + public long toTiB(long d) { return convertTo(d, TiB); } + public long toPiB(long d) { return convertTo(d, PiB); } + + private final long multiplier; +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index b6fbace509a0e..6b514aaa1290d 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -126,7 +126,7 @@ private static boolean isSymlink(File file) throws IOException { return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } - private static ImmutableMap timeSuffixes = + private static final ImmutableMap timeSuffixes = ImmutableMap.builder() .put("us", TimeUnit.MICROSECONDS) .put("ms", TimeUnit.MILLISECONDS) @@ -137,6 +137,21 @@ private static boolean isSymlink(File file) throws IOException { .put("d", TimeUnit.DAYS) .build(); + private static final ImmutableMap byteSuffixes = + ImmutableMap.builder() + .put("b", ByteUnit.BYTE) + .put("k", ByteUnit.KiB) + .put("kb", ByteUnit.KiB) + .put("m", ByteUnit.MiB) + .put("mb", ByteUnit.MiB) + .put("g", ByteUnit.GiB) + .put("gb", ByteUnit.GiB) + .put("t", ByteUnit.TiB) + .put("tb", ByteUnit.TiB) + .put("p", ByteUnit.PiB) + .put("pb", ByteUnit.PiB) + .build(); + /** * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count for * internal use. If no suffix is provided a direct conversion is attempted. @@ -145,16 +160,14 @@ private static long parseTimeString(String str, TimeUnit unit) { String lower = str.toLowerCase().trim(); try { - String suffix; - long val; Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); - if (m.matches()) { - val = Long.parseLong(m.group(1)); - suffix = m.group(2); - } else { + if (!m.matches()) { throw new NumberFormatException("Failed to parse time string: " + str); } + long val = Long.parseLong(m.group(1)); + String suffix = m.group(2); + // Check for invalid suffixes if (suffix != null && !timeSuffixes.containsKey(suffix)) { throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); @@ -164,7 +177,7 @@ private static long parseTimeString(String str, TimeUnit unit) { return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit); } catch (NumberFormatException e) { String timeError = "Time must be specified as seconds (s), " + - "milliseconds (ms), microseconds (us), minutes (m or min) hour (h), or day (d). " + + "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " + "E.g. 50s, 100ms, or 250us."; throw new NumberFormatException(timeError + "\n" + e.getMessage()); @@ -186,5 +199,83 @@ public static long timeStringAsMs(String str) { public static long timeStringAsSec(String str) { return parseTimeString(str, TimeUnit.SECONDS); } + + /** + * Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a ByteUnit for + * internal use. If no suffix is provided a direct conversion of the provided default is + * attempted. + */ + private static long parseByteString(String str, ByteUnit unit) { + String lower = str.toLowerCase().trim(); + + try { + Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); + Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower); + + if (m.matches()) { + long val = Long.parseLong(m.group(1)); + String suffix = m.group(2); + + // Check for invalid suffixes + if (suffix != null && !byteSuffixes.containsKey(suffix)) { + throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); + } + + // If suffix is valid use that, otherwise none was provided and use the default passed + return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); + } else if (fractionMatcher.matches()) { + throw new NumberFormatException("Fractional values are not supported. Input was: " + + fractionMatcher.group(1)); + } else { + throw new NumberFormatException("Failed to parse byte string: " + str); + } + + } catch (NumberFormatException e) { + String timeError = "Size must be specified as bytes (b), " + + "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " + + "E.g. 50b, 100k, or 250m."; + throw new NumberFormatException(timeError + "\n" + e.getMessage()); + } + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in bytes. + */ + public static long byteStringAsBytes(String str) { + return parseByteString(str, ByteUnit.BYTE); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in kibibytes. + */ + public static long byteStringAsKb(String str) { + return parseByteString(str, ByteUnit.KiB); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in mebibytes. + */ + public static long byteStringAsMb(String str) { + return parseByteString(str, ByteUnit.MiB); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to gibibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in gibibytes. + */ + public static long byteStringAsGb(String str) { + return parseByteString(str, ByteUnit.GiB); + } } diff --git a/pom.xml b/pom.xml index 9fbce1d639d8b..c85c5feeaf383 100644 --- a/pom.xml +++ b/pom.xml @@ -97,6 +97,7 @@ sql/catalyst sql/core sql/hive + unsafe assembly external/twitter external/flume @@ -1082,7 +1083,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.3.1 + 1.4 enforce-versions @@ -1105,7 +1106,7 @@ org.codehaus.mojo build-helper-maven-plugin - 1.8 + 1.9.1 net.alchim31.maven @@ -1176,7 +1177,7 @@ org.apache.maven.plugins maven-compiler-plugin - 3.1 + 3.3 ${java.version} ${java.version} @@ -1189,7 +1190,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.18 + 2.18.1 @@ -1215,6 +1216,7 @@ false false true + true false @@ -1260,17 +1262,17 @@ org.apache.maven.plugins maven-jar-plugin - 2.4 + 2.6 org.apache.maven.plugins maven-antrun-plugin - 1.7 + 1.8 org.apache.maven.plugins maven-source-plugin - 2.2.1 + 2.4 true @@ -1287,7 +1289,7 @@ org.apache.maven.plugins maven-clean-plugin - 2.5 + 2.6.1 @@ -1305,7 +1307,27 @@ org.apache.maven.plugins maven-javadoc-plugin - 2.10.1 + 2.10.3 + + + org.codehaus.mojo + exec-maven-plugin + 1.4.0 + + + org.apache.maven.plugins + maven-assembly-plugin + 2.5.3 + + + org.apache.maven.plugins + maven-install-plugin + 2.5.2 + + + org.apache.maven.plugins + maven-deploy-plugin + 2.8.2 @@ -1315,7 +1337,7 @@ org.apache.maven.plugins maven-dependency-plugin - 2.9 + 2.10 test-compile @@ -1334,7 +1356,7 @@ org.codehaus.gmavenplus gmavenplus-plugin - 1.2 + 1.5 process-test-classes @@ -1359,7 +1381,7 @@ org.apache.maven.plugins maven-shade-plugin - 2.2 + 2.3 false diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 967961c2bf5c3..bf343d4b7e40b 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -76,6 +76,34 @@ object MimaExcludes { // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.mllib.clustering.LDA$EMOptimizer") + ) ++ Seq( + // SPARK-6756 add toSparse, toDense, numActives, numNonzeros, and compressed to Vector + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.compressed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toDense"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.numNonzeros"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toSparse"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.numActives") + ) ++ Seq( + // This `protected[sql]` method was removed in 1.3.1 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.SQLContext.checkAnalysis"), + // This `private[sql]` class was removed in 1.4.0: + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.execution.AddExchange"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.execution.AddExchange$"), + // These test support classes were moved out of src/main and into src/test: + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTestData"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTestData$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.TestGroupWriteSupport") ) case v if v.startsWith("1.3") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 09b4976d10c26..b4431c7ee05b6 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -34,11 +34,11 @@ object BuildCommons { val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, - streamingMqtt, streamingTwitter, streamingZeromq, launcher) = + streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", - "streaming-zeromq", "launcher").map(ProjectRef(buildLocation, _)) + "streaming-zeromq", "launcher", "unsafe").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", @@ -156,13 +156,15 @@ object SparkBuild extends PomBuild { /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) - // TODO: Add Sql to mima checks - // TODO: remove launcher from this list after 1.3. - allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, - networkCommon, networkShuffle, networkYarn, launcher).contains(x)).foreach { + // TODO: remove launcher from this list after 1.4.0 + allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl, + networkCommon, networkShuffle, networkYarn, launcher, unsafe).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } + /* Unsafe settings */ + enable(Unsafe.settings)(unsafe) + /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) @@ -216,6 +218,13 @@ object SparkBuild extends PomBuild { } +object Unsafe { + lazy val settings = Seq( + // This option is needed to suppress warnings from sun.misc.Unsafe usage + javacOptions in Compile += "-XDignore.symbol.file" + ) +} + object Flume { lazy val settings = sbtavro.SbtAvro.avroSettings } @@ -424,6 +433,7 @@ object Unidoc { .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/network"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/shuffle"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/executor"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/unsafe"))) .map(_.filterNot(_.getCanonicalPath.contains("python"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/collection"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst"))) @@ -467,7 +477,7 @@ object Unidoc { "mllib.evaluation", "mllib.feature", "mllib.random", "mllib.stat.correlation", "mllib.stat.test", "mllib.tree.impl", "mllib.tree.loss", "ml", "ml.attribute", "ml.classification", "ml.evaluation", "ml.feature", "ml.param", - "ml.tuning" + "ml.recommendation", "ml.regression", "ml.tuning" ), "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"), "-noqualifier", "java.lang" @@ -496,6 +506,7 @@ object TestSettings { javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", + javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index cc9a4cf8ba170..a57c0b3ae0d00 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -39,7 +39,8 @@ IntegerType, ByteType -__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', 'DenseMatrix', 'Matrices'] +__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', + 'Matrix', 'DenseMatrix', 'SparseMatrix', 'Matrices'] if sys.version_info[:2] == (2, 7): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 4759f5fe783ad..5908ebc990a56 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -237,7 +237,8 @@ def explain(self, extended=False): :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. >>> df.explain() - PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at mapPartitions at SQLContext.scala:... + PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at applySchemaToPythonRDD at\ + NativeMethodAccessorImpl.java:... >>> df.explain(True) == Parsed Logical Plan == @@ -425,7 +426,7 @@ def distinct(self): def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. - >>> df.sample(False, 0.5, 97).count() + >>> df.sample(False, 0.5, 42).count() 1 """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction @@ -433,6 +434,27 @@ def sample(self, withReplacement, fraction, seed=None): rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) + def randomSplit(self, weights, seed=None): + """Randomly splits this :class:`DataFrame` with the provided weights. + + :param weights: list of doubles as weights with which to split the DataFrame. Weights will + be normalized if they don't sum up to 1.0. + :param seed: The seed for sampling. + + >>> splits = df4.randomSplit([1.0, 2.0], 24) + >>> splits[0].count() + 1 + + >>> splits[1].count() + 3 + """ + for w in weights: + if w < 0.0: + raise ValueError("Weights must be positive. Found weight value: %s" % w) + seed = seed if seed is not None else random.randint(0, sys.maxsize) + rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed)) + return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array] + @property def dtypes(self): """Returns all column names and their data types as a list. @@ -632,7 +654,8 @@ def __getattr__(self, name): [Row(age=2), Row(age=5)] """ if name not in self.columns: - raise AttributeError("No such column: %s" % name) + raise AttributeError( + "'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) jc = self._jdf.apply(name) return Column(jc) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f48b7b5d10af7..555c2fa5e7071 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -54,7 +54,7 @@ def _(col): 'upper': 'Converts a string expression to upper case.', 'lower': 'Converts a string expression to upper case.', 'sqrt': 'Computes the square root of the specified float value.', - 'abs': 'Computes the absolutle value.', + 'abs': 'Computes the absolute value.', 'max': 'Aggregate function: returns the maximum value of the expression in a group.', 'min': 'Aggregate function: returns the minimum value of the expression in a group.', @@ -103,8 +103,28 @@ def countDistinct(col, *cols): return Column(jc) +def monotonicallyIncreasingId(): + """A column that generates monotonically increasing 64-bit integers. + + The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + The current implementation puts the partition ID in the upper 31 bits, and the record number + within each partition in the lower 33 bits. The assumption is that the data frame has + less than 1 billion partitions, and each partition has less than 8 billion records. + + As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + This expression would return the following IDs: + 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + + >>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1']) + >>> df0.select(monotonicallyIncreasingId().alias('id')).collect() + [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.monotonicallyIncreasingId()) + + def sparkPartitionId(): - """Returns a column for partition ID of the Spark task. + """A column for partition ID of the Spark task. Note that this is indeterministic because it depends on data partitioning and task scheduling. diff --git a/python/pyspark/sql/mathfunctions.py b/python/pyspark/sql/mathfunctions.py new file mode 100644 index 0000000000000..7dbcab8694293 --- /dev/null +++ b/python/pyspark/sql/mathfunctions.py @@ -0,0 +1,101 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A collection of builtin math functions +""" + +from pyspark import SparkContext +from pyspark.sql.dataframe import Column + +__all__ = [] + + +def _create_unary_mathfunction(name, doc=""): + """ Create a unary mathfunction by name""" + def _(col): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.mathfunctions, name)(col._jc if isinstance(col, Column) else col) + return Column(jc) + _.__name__ = name + _.__doc__ = doc + return _ + + +def _create_binary_mathfunction(name, doc=""): + """ Create a binary mathfunction by name""" + def _(col1, col2): + sc = SparkContext._active_spark_context + # users might write ints for simplicity. This would throw an error on the JVM side. + if type(col1) is int: + col1 = col1 * 1.0 + if type(col2) is int: + col2 = col2 * 1.0 + jc = getattr(sc._jvm.mathfunctions, name)(col1._jc if isinstance(col1, Column) else col1, + col2._jc if isinstance(col2, Column) else col2) + return Column(jc) + _.__name__ = name + _.__doc__ = doc + return _ + + +# math functions are found under another object therefore, they need to be handled separately +_mathfunctions = { + 'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' + + '0.0 through pi.', + 'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' + + '-pi/2 through pi/2.', + 'atan': 'Computes the tangent inverse of the given value.', + 'cbrt': 'Computes the cube-root of the given value.', + 'ceil': 'Computes the ceiling of the given value.', + 'cos': 'Computes the cosine of the given value.', + 'cosh': 'Computes the hyperbolic cosine of the given value.', + 'exp': 'Computes the exponential of the given value.', + 'expm1': 'Computes the exponential of the given value minus one.', + 'floor': 'Computes the floor of the given value.', + 'log': 'Computes the natural logarithm of the given value.', + 'log10': 'Computes the logarithm of the given value in Base 10.', + 'log1p': 'Computes the natural logarithm of the given value plus one.', + 'rint': 'Returns the double value that is closest in value to the argument and' + + ' is equal to a mathematical integer.', + 'signum': 'Computes the signum of the given value.', + 'sin': 'Computes the sine of the given value.', + 'sinh': 'Computes the hyperbolic sine of the given value.', + 'tan': 'Computes the tangent of the given value.', + 'tanh': 'Computes the hyperbolic tangent of the given value.', + 'toDeg': 'Converts an angle measured in radians to an approximately equivalent angle ' + + 'measured in degrees.', + 'toRad': 'Converts an angle measured in degrees to an approximately equivalent angle ' + + 'measured in radians.' +} + +# math functions that take two arguments as input +_binary_mathfunctions = { + 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + + 'polar coordinates (r, theta).', + 'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.', + 'pow': 'Returns the value of the first argument raised to the power of the second argument.' +} + +for _name, _doc in _mathfunctions.items(): + globals()[_name] = _create_unary_mathfunction(_name, _doc) +for _name, _doc in _binary_mathfunctions.items(): + globals()[_name] = _create_binary_mathfunction(_name, _doc) +del _name, _doc +__all__ += _mathfunctions.keys() +__all__ += _binary_mathfunctions.keys() +__all__.sort() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fe43c374f1cb1..2ffd18ebd7c89 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -387,6 +387,35 @@ def test_aggregator(self): self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + def test_math_functions(self): + df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() + from pyspark.sql import mathfunctions as functions + import math + + def get_values(l): + return [j[0] for j in l] + + def assert_close(a, b): + c = get_values(b) + diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)] + return sum(diff) == len(a) + assert_close([math.cos(i) for i in range(10)], + df.select(functions.cos(df.a)).collect()) + assert_close([math.cos(i) for i in range(10)], + df.select(functions.cos("a")).collect()) + assert_close([math.sin(i) for i in range(10)], + df.select(functions.sin(df.a)).collect()) + assert_close([math.sin(i) for i in range(10)], + df.select(functions.sin(df['a'])).collect()) + assert_close([math.pow(i, 2 * i) for i in range(10)], + df.select(functions.pow(df.a, df.b)).collect()) + assert_close([math.pow(i, 2) for i in range(10)], + df.select(functions.pow(df.a, 2)).collect()) + assert_close([math.pow(i, 2) for i in range(10)], + df.select(functions.pow(df.a, 2.0)).collect()) + assert_close([math.hypot(i, 2 * i) for i in range(10)], + df.select(functions.hypot(df.a, df.b)).collect()) + def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 8d610d6569b4a..e278b29003f69 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -17,11 +17,12 @@ from py4j.java_gateway import Py4JJavaError +from pyspark.rdd import RDD from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer from pyspark.streaming import DStream -__all__ = ['KafkaUtils', 'utf8_decoder'] +__all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder'] def utf8_decoder(s): @@ -67,7 +68,104 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, except Py4JJavaError as e: # TODO: use --jar once it also work on driver if 'ClassNotFoundException' in str(e.java_exception): - print(""" + KafkaUtils._printErrorMsg(ssc.sparkContext) + raise e + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + """ + .. note:: Experimental + + Create an input stream that directly pulls messages from a Kafka Broker and specific offset. + + This is not a receiver based Kafka input stream, it directly pulls the message from Kafka + in each batch duration and processed without storing. + + This does not use Zookeeper to store offsets. The consumed offsets are tracked + by the stream itself. For interoperability with Kafka monitoring tools that depend on + Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + You can access the offsets used in each batch from the generated RDDs (see + + To recover from driver failures, you have to enable checkpointing in the StreamingContext. + The information on consumed offset can be recovered from the checkpoint. + See the programming guide for details (constraints, etc.). + + :param ssc: StreamingContext object. + :param topics: list of topic_name to consume. + :param kafkaParams: Additional params for Kafka. + :param fromOffsets: Per-topic/partition Kafka offsets defining the (inclusive) starting + point of the stream. + :param keyDecoder: A function used to decode key (default is utf8_decoder). + :param valueDecoder: A function used to decode value (default is utf8_decoder). + :return: A DStream object + """ + if not isinstance(topics, list): + raise TypeError("topics should be list") + if not isinstance(kafkaParams, dict): + raise TypeError("kafkaParams should be dict") + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + + jfromOffsets = dict([(k._jTopicAndPartition(helper), + v) for (k, v) in fromOffsets.items()]) + jstream = helper.createDirectStream(ssc._jssc, kafkaParams, set(topics), jfromOffsets) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(ssc.sparkContext) + raise e + + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def createRDD(sc, kafkaParams, offsetRanges, leaders={}, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + """ + .. note:: Experimental + + Create a RDD from Kafka using offset ranges for each topic and partition. + :param sc: SparkContext object + :param kafkaParams: Additional params for Kafka + :param offsetRanges: list of offsetRange to specify topic:partition:[start, end) to consume + :param leaders: Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty + map, in which case leaders will be looked up on the driver. + :param keyDecoder: A function used to decode key (default is utf8_decoder) + :param valueDecoder: A function used to decode value (default is utf8_decoder) + :return: A RDD object + """ + if not isinstance(kafkaParams, dict): + raise TypeError("kafkaParams should be dict") + if not isinstance(offsetRanges, list): + raise TypeError("offsetRanges should be list") + + try: + helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges] + jleaders = dict([(k._jTopicAndPartition(helper), + v._jBroker(helper)) for (k, v) in leaders.items()]) + jrdd = helper.createRDD(sc._jsc, kafkaParams, joffsetRanges, jleaders) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(sc) + raise e + + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + rdd = RDD(jrdd, sc, ser) + return rdd.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def _printErrorMsg(sc): + print(""" ________________________________________________________________________________________________ Spark Streaming's Kafka libraries not found in class path. Try one of the following. @@ -85,8 +183,63 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, ________________________________________________________________________________________________ -""" % (ssc.sparkContext.version, ssc.sparkContext.version)) - raise e - ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - stream = DStream(jstream, ssc, ser) - return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) +""" % (sc.version, sc.version)) + + +class OffsetRange(object): + """ + Represents a range of offsets from a single Kafka TopicAndPartition. + """ + + def __init__(self, topic, partition, fromOffset, untilOffset): + """ + Create a OffsetRange to represent range of offsets + :param topic: Kafka topic name. + :param partition: Kafka partition id. + :param fromOffset: Inclusive starting offset. + :param untilOffset: Exclusive ending offset. + """ + self._topic = topic + self._partition = partition + self._fromOffset = fromOffset + self._untilOffset = untilOffset + + def _jOffsetRange(self, helper): + return helper.createOffsetRange(self._topic, self._partition, self._fromOffset, + self._untilOffset) + + +class TopicAndPartition(object): + """ + Represents a specific top and partition for Kafka. + """ + + def __init__(self, topic, partition): + """ + Create a Python TopicAndPartition to map to the Java related object + :param topic: Kafka topic name. + :param partition: Kafka partition id. + """ + self._topic = topic + self._partition = partition + + def _jTopicAndPartition(self, helper): + return helper.createTopicAndPartition(self._topic, self._partition) + + +class Broker(object): + """ + Represent the host and port info for a Kafka broker. + """ + + def __init__(self, host, port): + """ + Create a Python Broker to map to the Java related object. + :param host: Broker's hostname. + :param port: Broker's port. + """ + self._host = host + self._port = port + + def _jBroker(self, helper): + return helper.createBroker(self._host, self._port) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5fa1e5ef081ab..7c06c203455d9 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -21,6 +21,7 @@ import time import operator import tempfile +import random import struct from functools import reduce @@ -35,7 +36,7 @@ from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext -from pyspark.streaming.kafka import KafkaUtils +from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition class PySparkStreamingTestCase(unittest.TestCase): @@ -590,9 +591,27 @@ def tearDown(self): super(KafkaStreamTests, self).tearDown() + def _randomTopic(self): + return "topic-%d" % random.randint(0, 10000) + + def _validateStreamResult(self, sendData, stream): + result = {} + for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), + sum(sendData.values()))): + result[i] = result.get(i, 0) + 1 + + self.assertEqual(sendData, result) + + def _validateRddResult(self, sendData, rdd): + result = {} + for i in rdd.map(lambda x: x[1]).collect(): + result[i] = result.get(i, 0) + 1 + + self.assertEqual(sendData, result) + def test_kafka_stream(self): """Test the Python Kafka stream API.""" - topic = "topic1" + topic = self._randomTopic() sendData = {"a": 3, "b": 5, "c": 10} self._kafkaTestUtils.createTopic(topic) @@ -601,13 +620,64 @@ def test_kafka_stream(self): stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), "test-streaming-consumer", {topic: 1}, {"auto.offset.reset": "smallest"}) + self._validateStreamResult(sendData, stream) - result = {} - for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), - sum(sendData.values()))): - result[i] = result.get(i, 0) + 1 + def test_kafka_direct_stream(self): + """Test the Python direct Kafka stream API.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} - self.assertEqual(sendData, result) + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) + self._validateStreamResult(sendData, stream) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_from_offset(self): + """Test the Python direct Kafka stream API with start offset specified.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + fromOffsets = {TopicAndPartition(topic, 0): long(0)} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets) + self._validateStreamResult(sendData, stream) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd(self): + """Test the Python direct Kafka RDD API.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) + self._validateRddResult(sendData, rdd) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd_with_leaders(self): + """Test the Python direct Kafka RDD API with leaders.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + address = self._kafkaTestUtils.brokerAddress().split(":") + leaders = {TopicAndPartition(topic, 0): Broker(address[0], int(address[1]))} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) + self._validateRddResult(sendData, rdd) if __name__ == "__main__": unittest.main() diff --git a/sbin/start-mesos-dispatcher.sh b/sbin/start-mesos-dispatcher.sh new file mode 100755 index 0000000000000..ef1fc573d5c65 --- /dev/null +++ b/sbin/start-mesos-dispatcher.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Starts the Mesos Cluster Dispatcher on the machine this script is executed on. +# The Mesos Cluster Dispatcher is responsible for launching the Mesos framework and +# Rest server to handle driver requests for Mesos cluster mode. +# Only one cluster dispatcher is needed per Mesos cluster. + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" + +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +if [ "$SPARK_MESOS_DISPATCHER_PORT" = "" ]; then + SPARK_MESOS_DISPATCHER_PORT=7077 +fi + +if [ "$SPARK_MESOS_DISPATCHER_HOST" = "" ]; then + SPARK_MESOS_DISPATCHER_HOST=`hostname` +fi + + +"$sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 --host $SPARK_MESOS_DISPATCHER_HOST --port $SPARK_MESOS_DISPATCHER_PORT "$@" diff --git a/sbin/start-shuffle-service.sh b/sbin/start-shuffle-service.sh new file mode 100755 index 0000000000000..4fddcf7f95d40 --- /dev/null +++ b/sbin/start-shuffle-service.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Starts the external shuffle server on the machine this script is executed on. +# +# Usage: start-shuffle-server.sh +# +# Use the SPARK_SHUFFLE_OPTS environment variable to set shuffle server configuration. +# + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +exec "$sbin"/spark-daemon.sh start org.apache.spark.deploy.ExternalShuffleService 1 diff --git a/sbin/stop-mesos-dispatcher.sh b/sbin/stop-mesos-dispatcher.sh new file mode 100755 index 0000000000000..cb65d95b5e524 --- /dev/null +++ b/sbin/stop-mesos-dispatcher.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Stop the Mesos Cluster dispatcher on the machine this script is executed on. + +sbin=`dirname "$0"` +sbin=`cd "$sbin"; pwd` + +. "$sbin/spark-config.sh" + +"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 + diff --git a/sbin/stop-shuffle-service.sh b/sbin/stop-shuffle-service.sh new file mode 100755 index 0000000000000..4cb6891ae27fa --- /dev/null +++ b/sbin/stop-shuffle-service.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Stops the external shuffle service on the machine this script is executed on. + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.ExternalShuffleService 1 diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 3dea2ee76542f..5c322d032d474 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -50,6 +50,11 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-unsafe_${scala.binary.version} + ${project.version} + org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java new file mode 100644 index 0000000000000..299ff3728a6d9 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import java.util.Arrays; +import java.util.Iterator; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.map.BytesToBytesMap; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +/** + * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. + * + * This map supports a maximum of 2 billion keys. + */ +public final class UnsafeFixedWidthAggregationMap { + + /** + * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the + * map, we copy this buffer and use it as the value. + */ + private final long[] emptyAggregationBuffer; + + private final StructType aggregationBufferSchema; + + private final StructType groupingKeySchema; + + /** + * Encodes grouping keys as UnsafeRows. + */ + private final UnsafeRowConverter groupingKeyToUnsafeRowConverter; + + /** + * A hashmap which maps from opaque bytearray keys to bytearray values. + */ + private final BytesToBytesMap map; + + /** + * Re-used pointer to the current aggregation buffer + */ + private final UnsafeRow currentAggregationBuffer = new UnsafeRow(); + + /** + * Scratch space that is used when encoding grouping keys into UnsafeRow format. + * + * By default, this is a 1MB array, but it will grow as necessary in case larger keys are + * encountered. + */ + private long[] groupingKeyConversionScratchSpace = new long[1024 / 8]; + + private final boolean enablePerfMetrics; + + /** + * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema, + * false otherwise. + */ + public static boolean supportsGroupKeySchema(StructType schema) { + for (StructField field: schema.fields()) { + if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) { + return false; + } + } + return true; + } + + /** + * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given + * schema, false otherwise. + */ + public static boolean supportsAggregationBufferSchema(StructType schema) { + for (StructField field: schema.fields()) { + if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { + return false; + } + } + return true; + } + + /** + * Create a new UnsafeFixedWidthAggregationMap. + * + * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) + * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. + * @param groupingKeySchema the schema of the grouping key, used for row conversion. + * @param memoryManager the memory manager used to allocate our Unsafe memory structures. + * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). + * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) + */ + public UnsafeFixedWidthAggregationMap( + Row emptyAggregationBuffer, + StructType aggregationBufferSchema, + StructType groupingKeySchema, + TaskMemoryManager memoryManager, + int initialCapacity, + boolean enablePerfMetrics) { + this.emptyAggregationBuffer = + convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema); + this.aggregationBufferSchema = aggregationBufferSchema; + this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema); + this.groupingKeySchema = groupingKeySchema; + this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); + this.enablePerfMetrics = enablePerfMetrics; + } + + /** + * Convert a Java object row into an UnsafeRow, allocating it into a new long array. + */ + private static long[] convertToUnsafeRow(Row javaRow, StructType schema) { + final UnsafeRowConverter converter = new UnsafeRowConverter(schema); + final long[] unsafeRow = new long[converter.getSizeRequirement(javaRow)]; + final long writtenLength = + converter.writeRow(javaRow, unsafeRow, PlatformDependent.LONG_ARRAY_OFFSET); + assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!"; + return unsafeRow; + } + + /** + * Return the aggregation buffer for the current group. For efficiency, all calls to this method + * return the same object. + */ + public UnsafeRow getAggregationBuffer(Row groupingKey) { + final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey); + // Make sure that the buffer is large enough to hold the key. If it's not, grow it: + if (groupingKeySize > groupingKeyConversionScratchSpace.length) { + // This new array will be initially zero, so there's no need to zero it out here + groupingKeyConversionScratchSpace = new long[groupingKeySize]; + } else { + // Zero out the buffer that's used to hold the current row. This is necessary in order + // to ensure that rows hash properly, since garbage data from the previous row could + // otherwise end up as padding in this row. As a performance optimization, we only zero out + // the portion of the buffer that we'll actually write to. + Arrays.fill(groupingKeyConversionScratchSpace, 0, groupingKeySize, 0); + } + final long actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow( + groupingKey, + groupingKeyConversionScratchSpace, + PlatformDependent.LONG_ARRAY_OFFSET); + assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!"; + + // Probe our map using the serialized key + final BytesToBytesMap.Location loc = map.lookup( + groupingKeyConversionScratchSpace, + PlatformDependent.LONG_ARRAY_OFFSET, + groupingKeySize); + if (!loc.isDefined()) { + // This is the first time that we've seen this grouping key, so we'll insert a copy of the + // empty aggregation buffer into the map: + loc.putNewKey( + groupingKeyConversionScratchSpace, + PlatformDependent.LONG_ARRAY_OFFSET, + groupingKeySize, + emptyAggregationBuffer, + PlatformDependent.LONG_ARRAY_OFFSET, + emptyAggregationBuffer.length + ); + } + + // Reset the pointer to point to the value that we just stored or looked up: + final MemoryLocation address = loc.getValueAddress(); + currentAggregationBuffer.pointTo( + address.getBaseObject(), + address.getBaseOffset(), + aggregationBufferSchema.length(), + aggregationBufferSchema + ); + return currentAggregationBuffer; + } + + /** + * Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}. + */ + public static class MapEntry { + private MapEntry() { }; + public final UnsafeRow key = new UnsafeRow(); + public final UnsafeRow value = new UnsafeRow(); + } + + /** + * Returns an iterator over the keys and values in this map. + * + * For efficiency, each call returns the same object. + */ + public Iterator iterator() { + return new Iterator() { + + private final MapEntry entry = new MapEntry(); + private final Iterator mapLocationIterator = map.iterator(); + + @Override + public boolean hasNext() { + return mapLocationIterator.hasNext(); + } + + @Override + public MapEntry next() { + final BytesToBytesMap.Location loc = mapLocationIterator.next(); + final MemoryLocation keyAddress = loc.getKeyAddress(); + final MemoryLocation valueAddress = loc.getValueAddress(); + entry.key.pointTo( + keyAddress.getBaseObject(), + keyAddress.getBaseOffset(), + groupingKeySchema.length(), + groupingKeySchema + ); + entry.value.pointTo( + valueAddress.getBaseObject(), + valueAddress.getBaseOffset(), + aggregationBufferSchema.length(), + aggregationBufferSchema + ); + return entry; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + /** + * Free the unsafe memory associated with this map. + */ + public void free() { + map.free(); + } + + @SuppressWarnings("UseOfSystemOutOrSystemErr") + public void printPerfMetrics() { + if (!enablePerfMetrics) { + throw new IllegalStateException("Perf metrics not enabled"); + } + System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup()); + System.out.println("Number of hash collisions: " + map.getNumHashCollisions()); + System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs()); + System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); + } + +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java new file mode 100644 index 0000000000000..0a358ed408aa1 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import scala.collection.Map; +import scala.collection.Seq; +import scala.collection.mutable.ArraySeq; + +import javax.annotation.Nullable; +import java.math.BigDecimal; +import java.sql.Date; +import java.util.*; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.DataType; +import static org.apache.spark.sql.types.DataTypes.*; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.UTF8String; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.bitset.BitSetMethods; + +/** + * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. + * + * Each tuple has three parts: [null bit set] [values] [variable length portion] + * + * The bit set is used for null tracking and is aligned to 8-byte word boundaries. It stores + * one bit per field. + * + * In the `values` region, we store one 8-byte word per field. For fields that hold fixed-length + * primitive types, such as long, double, or int, we store the value directly in the word. For + * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the + * base address of the row) that points to the beginning of the variable-length field. + * + * Instances of `UnsafeRow` act as pointers to row data stored in this format. + */ +public final class UnsafeRow implements MutableRow { + + private Object baseObject; + private long baseOffset; + + Object getBaseObject() { return baseObject; } + long getBaseOffset() { return baseOffset; } + + /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ + private int numFields; + + /** The width of the null tracking bit set, in bytes */ + private int bitSetWidthInBytes; + /** + * This optional schema is required if you want to call generic get() and set() methods on + * this UnsafeRow, but is optional if callers will only use type-specific getTYPE() and setTYPE() + * methods. This should be removed after the planned InternalRow / Row split; right now, it's only + * needed by the generic get() method, which is only called internally by code that accesses + * UTF8String-typed columns. + */ + @Nullable + private StructType schema; + + private long getFieldOffset(int ordinal) { + return baseOffset + bitSetWidthInBytes + ordinal * 8L; + } + + public static int calculateBitSetWidthInBytes(int numFields) { + return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8; + } + + /** + * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) + */ + public static final Set settableFieldTypes; + + /** + * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException). + */ + public static final Set readableFieldTypes; + + static { + settableFieldTypes = Collections.unmodifiableSet( + new HashSet( + Arrays.asList(new DataType[] { + NullType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType + }))); + + // We support get() on a superset of the types for which we support set(): + final Set _readableFieldTypes = new HashSet( + Arrays.asList(new DataType[]{ + StringType + })); + _readableFieldTypes.addAll(settableFieldTypes); + readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); + } + + /** + * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called, + * since the value returned by this constructor is equivalent to a null pointer. + */ + public UnsafeRow() { } + + /** + * Update this UnsafeRow to point to different backing data. + * + * @param baseObject the base object + * @param baseOffset the offset within the base object + * @param numFields the number of fields in this row + * @param schema an optional schema; this is necessary if you want to call generic get() or set() + * methods on this row, but is optional if the caller will only use type-specific + * getTYPE() and setTYPE() methods. + */ + public void pointTo( + Object baseObject, + long baseOffset, + int numFields, + @Nullable StructType schema) { + assert numFields >= 0 : "numFields should >= 0"; + assert schema == null || schema.fields().length == numFields; + this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); + this.baseObject = baseObject; + this.baseOffset = baseOffset; + this.numFields = numFields; + this.schema = schema; + } + + private void assertIndexIsValid(int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < numFields : "index (" + index + ") should <= " + numFields; + } + + @Override + public void setNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.set(baseObject, baseOffset, i); + // To preserve row equality, zero out the value when setting the column to null. + // Since this row does does not currently support updates to variable-length values, we don't + // have to worry about zeroing out that data. + PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0); + } + + private void setNotNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.unset(baseObject, baseOffset, i); + } + + @Override + public void update(int ordinal, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setInt(int ordinal, int value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putInt(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setLong(int ordinal, long value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setDouble(int ordinal, double value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setBoolean(int ordinal, boolean value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putBoolean(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setShort(int ordinal, short value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putShort(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setByte(int ordinal, byte value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putByte(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setFloat(int ordinal, float value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setString(int ordinal, String value) { + throw new UnsupportedOperationException(); + } + + @Override + public int size() { + return numFields; + } + + @Override + public int length() { + return size(); + } + + @Override + public StructType schema() { + return schema; + } + + @Override + public Object apply(int i) { + return get(i); + } + + @Override + public Object get(int i) { + assertIndexIsValid(i); + assert (schema != null) : "Schema must be defined when calling generic get() method"; + final DataType dataType = schema.fields()[i].dataType(); + // UnsafeRow is only designed to be invoked by internal code, which only invokes this generic + // get() method when trying to access UTF8String-typed columns. If we refactor the codebase to + // separate the internal and external row interfaces, then internal code can fetch strings via + // a new getUTF8String() method and we'll be able to remove this method. + if (isNullAt(i)) { + return null; + } else if (dataType == StringType) { + return getUTF8String(i); + } else { + throw new UnsupportedOperationException(); + } + } + + @Override + public boolean isNullAt(int i) { + assertIndexIsValid(i); + return BitSetMethods.isSet(baseObject, baseOffset, i); + } + + @Override + public boolean getBoolean(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(i)); + } + + @Override + public byte getByte(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(i)); + } + + @Override + public short getShort(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(i)); + } + + @Override + public int getInt(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(i)); + } + + @Override + public long getLong(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i)); + } + + @Override + public float getFloat(int i) { + assertIndexIsValid(i); + if (isNullAt(i)) { + return Float.NaN; + } else { + return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(i)); + } + } + + @Override + public double getDouble(int i) { + assertIndexIsValid(i); + if (isNullAt(i)) { + return Float.NaN; + } else { + return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(i)); + } + } + + public UTF8String getUTF8String(int i) { + assertIndexIsValid(i); + final UTF8String str = new UTF8String(); + final long offsetToStringSize = getLong(i); + final int stringSizeInBytes = + (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize); + final byte[] strBytes = new byte[stringSizeInBytes]; + PlatformDependent.copyMemory( + baseObject, + baseOffset + offsetToStringSize + 8, // The `+ 8` is to skip past the size to get the data + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + stringSizeInBytes + ); + str.set(strBytes); + return str; + } + + @Override + public String getString(int i) { + return getUTF8String(i).toString(); + } + + @Override + public BigDecimal getDecimal(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Seq getSeq(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public List getList(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Map getMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public scala.collection.immutable.Map getValuesMap(Seq fieldNames) { + throw new UnsupportedOperationException(); + } + + @Override + public java.util.Map getJavaMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Row getStruct(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public T getAs(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public T getAs(String fieldName) { + throw new UnsupportedOperationException(); + } + + @Override + public int fieldIndex(String name) { + throw new UnsupportedOperationException(); + } + + @Override + public Row copy() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean anyNull() { + return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes); + } + + @Override + public Seq toSeq() { + final ArraySeq values = new ArraySeq(numFields); + for (int fieldNumber = 0; fieldNumber < numFields; fieldNumber++) { + values.update(fieldNumber, get(fieldNumber)); + } + return values; + } + + @Override + public String toString() { + return mkString("[", ",", "]"); + } + + @Override + public String mkString() { + return toSeq().mkString(); + } + + @Override + public String mkString(String sep) { + return toSeq().mkString(sep); + } + + @Override + public String mkString(String start, String sep, String end) { + return toSeq().mkString(start, sep, end); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 1f3c02478bd68..2eb3e167baad5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -25,10 +25,6 @@ import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.plans.logical._ -private[sql] object KeywordNormalizer { - def apply(str: String): String = str.toLowerCase() -} - private[sql] abstract class AbstractSparkSQLParser extends StandardTokenParsers with PackratParsers { @@ -42,7 +38,7 @@ private[sql] abstract class AbstractSparkSQLParser } protected case class Keyword(str: String) { - def normalize: String = KeywordNormalizer(str) + def normalize: String = lexical.normalizeKeyword(str) def parser: Parser[String] = normalize } @@ -90,13 +86,16 @@ class SqlLexical extends StdLexical { reserved ++= keywords } + /* Normal the keyword string */ + def normalizeKeyword(str: String): String = str.toLowerCase + delimiters += ( "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>" ) protected override def processIdent(name: String) = { - val token = KeywordNormalizer(name) + val token = normalizeKeyword(name) if (reserved contains token) Keyword(token) else Identifier(name) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/Dialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/Dialect.scala new file mode 100644 index 0000000000000..977003493d471 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/Dialect.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * Root class of SQL Parser Dialect, and we don't guarantee the binary + * compatibility for the future release, let's keep it as the internal + * interface for advanced user. + * + */ +@DeveloperApi +abstract class Dialect { + // this is the main function that will be implemented by sql parser. + def parse(sqlText: String): LogicalPlan +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 0af969cc5cc67..1d3a2dc0d9bb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -365,6 +365,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val baseExpression: Parser[Expression] = ( "*" ^^^ UnresolvedStar(None) + | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) } | primary ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 35c7f00d4e42a..73c9a1c7afdad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -79,6 +79,7 @@ trait HiveTypeCoercion { CaseWhenCoercion :: Division :: PropagateTypes :: + ExpectedInputConversion :: Nil /** @@ -643,4 +644,22 @@ trait HiveTypeCoercion { } } + /** + * Casts types according to the expected input types for Expressions that have the trait + * `ExpectsInputTypes`. + */ + object ExpectedInputConversion extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + + case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => + val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { + case (child, actual, expected) => + if (actual == expected) child else Cast(child, expected) + } + e.withNewChildren(newC) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 5d5aba9644ff7..fa6cc7a1a36cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -278,12 +278,6 @@ package object dsl { def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan = Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) - def sample( - fraction: Double, - withReplacement: Boolean = true, - seed: Int = (math.random * 1000).toInt): LogicalPlan = - Sample(fraction, withReplacement, seed, logicalPlan) - // TODO specify the output column names def generate( generator: Generator, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala index bdeb660b1ecb7..0fd4f9b374ee0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala @@ -38,6 +38,8 @@ package object errors { } } + class DialectException(msg: String, cause: Throwable) extends Exception(msg, cause) + /** * Wraps any exceptions that are thrown while executing `f` in a * [[catalyst.errors.TreeNodeException TreeNodeException]], attaching the provided `tree`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4e3bbc06a5b4c..1d71c1b4b0c7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -109,3 +109,13 @@ case class GroupExpression(children: Seq[Expression]) extends Expression { override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException } + +/** + * Expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. + */ +trait ExpectsInputTypes { + + def expectedChildTypes: Seq[DataType] + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 3475ed05f4454..aa4099e4d7bf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -202,6 +202,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR case DoubleType => new MutableDouble case BooleanType => new MutableBoolean case LongType => new MutableLong + case DateType => new MutableInt // We use INT for DATE internally case _ => new MutableAny }.toArray) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala new file mode 100644 index 0000000000000..5b2c8572784bd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.array.ByteArrayMethods + +/** + * Converts Rows into UnsafeRow format. This class is NOT thread-safe. + * + * @param fieldTypes the data types of the row's columns. + */ +class UnsafeRowConverter(fieldTypes: Array[DataType]) { + + def this(schema: StructType) { + this(schema.fields.map(_.dataType)) + } + + /** Re-used pointer to the unsafe row being written */ + private[this] val unsafeRow = new UnsafeRow() + + /** Functions for encoding each column */ + private[this] val writers: Array[UnsafeColumnWriter] = { + fieldTypes.map(t => UnsafeColumnWriter.forType(t)) + } + + /** The size, in bytes, of the fixed-length portion of the row, including the null bitmap */ + private[this] val fixedLengthSize: Int = + (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length) + + /** + * Compute the amount of space, in bytes, required to encode the given row. + */ + def getSizeRequirement(row: Row): Int = { + var fieldNumber = 0 + var variableLengthFieldSize: Int = 0 + while (fieldNumber < writers.length) { + if (!row.isNullAt(fieldNumber)) { + variableLengthFieldSize += writers(fieldNumber).getSize(row, fieldNumber) + } + fieldNumber += 1 + } + fixedLengthSize + variableLengthFieldSize + } + + /** + * Convert the given row into UnsafeRow format. + * + * @param row the row to convert + * @param baseObject the base object of the destination address + * @param baseOffset the base offset of the destination address + * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`. + */ + def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = { + unsafeRow.pointTo(baseObject, baseOffset, writers.length, null) + var fieldNumber = 0 + var appendCursor: Int = fixedLengthSize + while (fieldNumber < writers.length) { + if (row.isNullAt(fieldNumber)) { + unsafeRow.setNullAt(fieldNumber) + } else { + appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor) + } + fieldNumber += 1 + } + appendCursor + } + +} + +/** + * Function for writing a column into an UnsafeRow. + */ +private abstract class UnsafeColumnWriter { + /** + * Write a value into an UnsafeRow. + * + * @param source the row being converted + * @param target a pointer to the converted unsafe row + * @param column the column to write + * @param appendCursor the offset from the start of the unsafe row to the end of the row; + * used for calculating where variable-length data should be written + * @return the number of variable-length bytes written + */ + def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int + + /** + * Return the number of bytes that are needed to write this variable-length value. + */ + def getSize(source: Row, column: Int): Int +} + +private object UnsafeColumnWriter { + + def forType(dataType: DataType): UnsafeColumnWriter = { + dataType match { + case NullType => NullUnsafeColumnWriter + case BooleanType => BooleanUnsafeColumnWriter + case ByteType => ByteUnsafeColumnWriter + case ShortType => ShortUnsafeColumnWriter + case IntegerType => IntUnsafeColumnWriter + case LongType => LongUnsafeColumnWriter + case FloatType => FloatUnsafeColumnWriter + case DoubleType => DoubleUnsafeColumnWriter + case StringType => StringUnsafeColumnWriter + case t => + throw new UnsupportedOperationException(s"Do not know how to write columns of type $t") + } + } +} + +// ------------------------------------------------------------------------------------------------ + +private object NullUnsafeColumnWriter extends NullUnsafeColumnWriter +private object BooleanUnsafeColumnWriter extends BooleanUnsafeColumnWriter +private object ByteUnsafeColumnWriter extends ByteUnsafeColumnWriter +private object ShortUnsafeColumnWriter extends ShortUnsafeColumnWriter +private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter +private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter +private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter +private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter +private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter + +private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { + // Primitives don't write to the variable-length region: + def getSize(sourceRow: Row, column: Int): Int = 0 +} + +private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setNullAt(column) + 0 + } +} + +private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setBoolean(column, source.getBoolean(column)) + 0 + } +} + +private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setByte(column, source.getByte(column)) + 0 + } +} + +private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setShort(column, source.getShort(column)) + 0 + } +} + +private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setInt(column, source.getInt(column)) + 0 + } +} + +private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setLong(column, source.getLong(column)) + 0 + } +} + +private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setFloat(column, source.getFloat(column)) + 0 + } +} + +private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setDouble(column, source.getDouble(column)) + 0 + } +} + +private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter { + def getSize(source: Row, column: Int): Int = { + val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length + 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + } + + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + val value = source.get(column).asInstanceOf[UTF8String] + val baseObject = target.getBaseObject + val baseOffset = target.getBaseOffset + val numBytes = value.getBytes.length + PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes) + PlatformDependent.copyMemory( + value.getBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + baseOffset + appendCursor + 8, + numBytes + ) + target.setLong(column, appendCursor) + 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index dbc92fb93e95e..d17af0e7ff87e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -672,6 +672,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case DoubleType => ru.Literal(Constant(-1.toDouble)) case DecimalType() => q"org.apache.spark.sql.types.Decimal(-1)" case IntegerType => ru.Literal(Constant(-1)) + case DateType => ru.Literal(Constant(-1)) case _ => ru.Literal(Constant(null)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala new file mode 100644 index 0000000000000..fcc06d3aa1036 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.mathfuncs + +import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row} +import org.apache.spark.sql.types._ + +/** + * A binary expression specifically for math functions that take two `Double`s as input and returns + * a `Double`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) + extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + type EvaluatedType = Any + override def symbol: String = null + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + + override def nullable: Boolean = left.nullable || right.nullable + override def toString: String = s"$name($left, $right)" + + override lazy val resolved = + left.resolved && right.resolved && + left.dataType == right.dataType && + !DecimalType.isFixed(left.dataType) + + override def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + } + left.dataType + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } + } +} + +case class Atan2( + left: Expression, + right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 + val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, + evalE2.asInstanceOf[Double] + 0.0) + if (result.isNaN) null else result + } + } + } +} + +case class Hypot( + left: Expression, + right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") + +case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala new file mode 100644 index 0000000000000..dc68469e060cb --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.mathfuncs + +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression} +import org.apache.spark.sql.types._ + +/** + * A unary expression specifically for math functions. Math Functions expect a specific type of + * input format, therefore these functions extend `ExpectsInputTypes`. + * @param name The short name of the function + */ +abstract class MathematicalExpression(f: Double => Double, name: String) + extends UnaryExpression with Serializable with ExpectsInputTypes { + self: Product => + type EvaluatedType = Any + + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) + override def dataType: DataType = DoubleType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = true + override def toString: String = s"$name($child)" + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val result = f(evalE.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } +} + +case class Acos(child: Expression) extends MathematicalExpression(math.acos, "ACOS") + +case class Asin(child: Expression) extends MathematicalExpression(math.asin, "ASIN") + +case class Atan(child: Expression) extends MathematicalExpression(math.atan, "ATAN") + +case class Cbrt(child: Expression) extends MathematicalExpression(math.cbrt, "CBRT") + +case class Ceil(child: Expression) extends MathematicalExpression(math.ceil, "CEIL") + +case class Cos(child: Expression) extends MathematicalExpression(math.cos, "COS") + +case class Cosh(child: Expression) extends MathematicalExpression(math.cosh, "COSH") + +case class Exp(child: Expression) extends MathematicalExpression(math.exp, "EXP") + +case class Expm1(child: Expression) extends MathematicalExpression(math.expm1, "EXPM1") + +case class Floor(child: Expression) extends MathematicalExpression(math.floor, "FLOOR") + +case class Log(child: Expression) extends MathematicalExpression(math.log, "LOG") + +case class Log10(child: Expression) extends MathematicalExpression(math.log10, "LOG10") + +case class Log1p(child: Expression) extends MathematicalExpression(math.log1p, "LOG1P") + +case class Rint(child: Expression) extends MathematicalExpression(math.rint, "ROUND") + +case class Signum(child: Expression) extends MathematicalExpression(math.signum, "SIGNUM") + +case class Sin(child: Expression) extends MathematicalExpression(math.sin, "SIN") + +case class Sinh(child: Expression) extends MathematicalExpression(math.sinh, "SINH") + +case class Tan(child: Expression) extends MathematicalExpression(math.tan, "TAN") + +case class Tanh(child: Expression) extends MathematicalExpression(math.tanh, "TANH") + +case class ToDegrees(child: Expression) + extends MathematicalExpression(math.toDegrees, "DEGREES") + +case class ToRadians(child: Expression) + extends MathematicalExpression(math.toRadians, "RADIANS") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2d03fbfb0d311..709f7d672d931 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -36,7 +36,13 @@ object DefaultOptimizer extends Optimizer { // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: - Batch("Combine Limits", FixedPoint(100), + Batch("Operator Reordering", FixedPoint(100), + UnionPushdown, + CombineFilters, + PushPredicateThroughProject, + PushPredicateThroughJoin, + PushPredicateThroughGenerate, + ColumnPruning, CombineLimits) :: Batch("ConstantFolding", FixedPoint(100), NullPropagation, @@ -49,13 +55,6 @@ object DefaultOptimizer extends Optimizer { OptimizeIn) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: - Batch("Filter Pushdown", FixedPoint(100), - UnionPushdown, - CombineFilters, - PushPredicateThroughProject, - PushPredicateThroughJoin, - PushPredicateThroughGenerate, - ColumnPruning) :: Batch("LocalRelation", FixedPoint(100), ConvertToLocalRelation) :: Nil } @@ -171,6 +170,9 @@ object ColumnPruning extends Rule[LogicalPlan] { Project(substitutedProjection, child) + case Project(projectList, Limit(exp, child)) => + Limit(exp, Project(projectList, child)) + // Eliminate no-op Projects case Project(projectList, child) if child.output == projectList => child } @@ -569,7 +571,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) joinType match { - case Inner => + case _ @ (Inner | LeftSemi) => // push down the single side only join filter for both sides sub queries val newLeft = leftJoinConditions. reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) @@ -577,7 +579,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = commonJoinCondition.reduceLeftOption(And) - Join(newLeft, newRight, Inner, newJoinCond) + Join(newLeft, newRight, joinType, newJoinCond) case RightOuter => // push down the left side only join filter for left side sub query val newLeft = leftJoinConditions. @@ -586,14 +588,14 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And) Join(newLeft, newRight, RightOuter, newJoinCond) - case _ @ (LeftOuter | LeftSemi) => + case LeftOuter => // push down the right side only join filter for right sub query val newLeft = left val newRight = rightJoinConditions. reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And) - Join(newLeft, newRight, joinType, newJoinCond) + Join(newLeft, newRight, LeftOuter, newJoinCond) case FullOuter => f } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index bbc94a7ab3398..21208c8a5c281 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -300,8 +300,22 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) } -case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan) - extends UnaryNode { +/** + * Sample the dataset. + * + * @param lowerBound Lower-bound of the sampling probability (usually 0.0) + * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled + * will be ub - lb. + * @param withReplacement Whether to sample with replacement. + * @param seed the random seed + * @param child the LogicalPlan + */ +case class Sample( + lowerBound: Double, + upperBound: Double, + withReplacement: Boolean, + seed: Long, + child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } @@ -310,6 +324,17 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } +/** + * Return a new RDD that has exactly `numPartitions` partitions. Differs from + * [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user + * asked for `coalesce` or `repartition`. [[RepartitionByExpression]] is used when the consumer + * of the output requires some specific ordering or distribution of the data. + */ +case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) + extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + /** * A relation with one row. This is used in "SELECT ..." without a from clause. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index e737418d9c3bc..63df2c1ee72ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -32,5 +32,11 @@ abstract class RedistributeData extends UnaryNode { case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) extends RedistributeData -case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan) +/** + * This method repartitions data using [[Expression]]s, and receives information about the + * number of partitions during execution. Used when a specific ordering or distribution is + * expected by the consumer of the query result. Use [[Repartition]] for RDD-like + * `coalesce` and `repartition`. + */ +case class RepartitionByExpression(partitionExpressions: Seq[Expression], child: LogicalPlan) extends RedistributeData diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 76298f03c94ae..fa71001c9336e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.mathfuncs._ import org.apache.spark.sql.types._ @@ -1152,6 +1153,158 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(c1 ^ c2, 3, row) checkEvaluation(~c1, -2, row) } + + /** + * Used for testing math functions for DataFrames. + * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + * @param expectNull Whether the given values should return null or not + * @tparam T Generic type for primitives + */ + def unaryMathFunctionEvaluation[@specialized(Int, Double, Float, Long) T]( + c: Expression => Expression, + f: T => T, + domain: Iterable[T] = (-20 to 20).map(_ * 0.1), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { value => + checkEvaluation(c(Literal(value)), null, EmptyRow) + } + } else { + domain.foreach { value => + checkEvaluation(c(Literal(value)), f(value), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("sin") { + unaryMathFunctionEvaluation(Sin, math.sin) + } + + test("asin") { + unaryMathFunctionEvaluation(Asin, math.asin, (-10 to 10).map(_ * 0.1)) + unaryMathFunctionEvaluation(Asin, math.asin, (11 to 20).map(_ * 0.1), true) + } + + test("sinh") { + unaryMathFunctionEvaluation(Sinh, math.sinh) + } + + test("cos") { + unaryMathFunctionEvaluation(Cos, math.cos) + } + + test("acos") { + unaryMathFunctionEvaluation(Acos, math.acos, (-10 to 10).map(_ * 0.1)) + unaryMathFunctionEvaluation(Acos, math.acos, (11 to 20).map(_ * 0.1), true) + } + + test("cosh") { + unaryMathFunctionEvaluation(Cosh, math.cosh) + } + + test("tan") { + unaryMathFunctionEvaluation(Tan, math.tan) + } + + test("atan") { + unaryMathFunctionEvaluation(Atan, math.atan) + } + + test("tanh") { + unaryMathFunctionEvaluation(Tanh, math.tanh) + } + + test("toDeg") { + unaryMathFunctionEvaluation(ToDegrees, math.toDegrees) + } + + test("toRad") { + unaryMathFunctionEvaluation(ToRadians, math.toRadians) + } + + test("cbrt") { + unaryMathFunctionEvaluation(Cbrt, math.cbrt) + } + + test("ceil") { + unaryMathFunctionEvaluation(Ceil, math.ceil) + } + + test("floor") { + unaryMathFunctionEvaluation(Floor, math.floor) + } + + test("rint") { + unaryMathFunctionEvaluation(Rint, math.rint) + } + + test("exp") { + unaryMathFunctionEvaluation(Exp, math.exp) + } + + test("expm1") { + unaryMathFunctionEvaluation(Expm1, math.expm1) + } + + test("signum") { + unaryMathFunctionEvaluation[Double](Signum, math.signum) + } + + test("log") { + unaryMathFunctionEvaluation(Log, math.log, (0 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log, math.log, (-5 to -1).map(_ * 0.1), true) + } + + test("log10") { + unaryMathFunctionEvaluation(Log10, math.log10, (0 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log10, math.log10, (-5 to -1).map(_ * 0.1), true) + } + + test("log1p") { + unaryMathFunctionEvaluation(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), true) + } + + /** + * Used for testing math functions for DataFrames. + * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + */ + def binaryMathFunctionEvaluation( + c: (Expression, Expression) => Expression, + f: (Double, Double) => Double, + domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { case (v1, v2) => + checkEvaluation(c(v1, v2), null, create_row(null)) + } + } else { + domain.foreach { case (v1, v2) => + checkEvaluation(c(v1, v2), f(v1 + 0.0, v2 + 0.0), EmptyRow) + checkEvaluation(c(v2, v1), f(v2 + 0.0, v1 + 0.0), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType), 1.0), null, create_row(null)) + checkEvaluation(c(1.0, Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("pow") { + binaryMathFunctionEvaluation(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) + binaryMathFunctionEvaluation(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), true) + } + + test("hypot") { + binaryMathFunctionEvaluation(Hypot, math.hypot) + } + + test("atan2") { + binaryMathFunctionEvaluation(Atan2, math.atan2) + } } // TODO: Make the tests work with codegen. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala new file mode 100644 index 0000000000000..7a19e511eb8b5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.JavaConverters._ +import scala.util.Random + +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator} +import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} + +import org.apache.spark.sql.types._ + +class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with BeforeAndAfterEach { + + import UnsafeFixedWidthAggregationMap._ + + private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) + private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) + private def emptyAggregationBuffer: Row = new GenericRow(Array[Any](0)) + + private var memoryManager: TaskMemoryManager = null + + override def beforeEach(): Unit = { + memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + } + + override def afterEach(): Unit = { + if (memoryManager != null) { + memoryManager.cleanUpAllAllocatedMemory() + memoryManager = null + } + } + + test("supported schemas") { + assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) + assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil))) + + assert( + !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) + assert( + !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) + } + + test("empty map") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + memoryManager, + 1024, // initial capacity + false // disable perf metrics + ) + assert(!map.iterator().hasNext) + map.free() + } + + test("updating values for a single key") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + memoryManager, + 1024, // initial capacity + false // disable perf metrics + ) + val groupKey = new GenericRow(Array[Any](UTF8String("cats"))) + + // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts) + map.getAggregationBuffer(groupKey) + val iter = map.iterator() + val entry = iter.next() + assert(!iter.hasNext) + entry.key.getString(0) should be ("cats") + entry.value.getInt(0) should be (0) + + // Modifications to rows retrieved from the map should update the values in the map + entry.value.setInt(0, 42) + map.getAggregationBuffer(groupKey).getInt(0) should be (42) + + map.free() + } + + test("inserting large random keys") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + memoryManager, + 128, // initial capacity + false // disable perf metrics + ) + val rand = new Random(42) + val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet + groupKeys.foreach { keyString => + map.getAggregationBuffer(new GenericRow(Array[Any](UTF8String(keyString)))) + } + val seenKeys: Set[String] = map.iterator().asScala.map { entry => + entry.key.getString(0) + }.toSet + seenKeys.size should be (groupKeys.size) + seenKeys should be (groupKeys) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala new file mode 100644 index 0000000000000..3a60c7fd32675 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.Arrays + +import org.scalatest.{FunSuite, Matchers} + +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.array.ByteArrayMethods + +class UnsafeRowConverterSuite extends FunSuite with Matchers { + + test("basic conversion with only primitive types") { + val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) + val converter = new UnsafeRowConverter(fieldTypes) + + val row = new SpecificMutableRow(fieldTypes) + row.setLong(0, 0) + row.setLong(1, 1) + row.setInt(2, 2) + + val sizeRequired: Int = converter.getSizeRequirement(row) + sizeRequired should be (8 + (3 * 8)) + val buffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) + numBytesWritten should be (sizeRequired) + + val unsafeRow = new UnsafeRow() + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + unsafeRow.getLong(0) should be (0) + unsafeRow.getLong(1) should be (1) + unsafeRow.getInt(2) should be (2) + } + + test("basic conversion with primitive and string types") { + val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType) + val converter = new UnsafeRowConverter(fieldTypes) + + val row = new SpecificMutableRow(fieldTypes) + row.setLong(0, 0) + row.setString(1, "Hello") + row.setString(2, "World") + + val sizeRequired: Int = converter.getSizeRequirement(row) + sizeRequired should be (8 + (8 * 3) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8)) + val buffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) + numBytesWritten should be (sizeRequired) + + val unsafeRow = new UnsafeRow() + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + unsafeRow.getLong(0) should be (0) + unsafeRow.getString(1) should be ("Hello") + unsafeRow.getString(2) should be ("World") + } + + test("null handling") { + val fieldTypes: Array[DataType] = Array( + NullType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType) + val converter = new UnsafeRowConverter(fieldTypes) + + val rowWithAllNullColumns: Row = { + val r = new SpecificMutableRow(fieldTypes) + for (i <- 0 to fieldTypes.length - 1) { + r.setNullAt(i) + } + r + } + + val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns) + val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow( + rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET) + numBytesWritten should be (sizeRequired) + + val createdFromNull = new UnsafeRow() + createdFromNull.pointTo( + createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + for (i <- 0 to fieldTypes.length - 1) { + assert(createdFromNull.isNullAt(i)) + } + createdFromNull.getBoolean(1) should be (false) + createdFromNull.getByte(2) should be (0) + createdFromNull.getShort(3) should be (0) + createdFromNull.getInt(4) should be (0) + createdFromNull.getLong(5) should be (0) + assert(java.lang.Float.isNaN(createdFromNull.getFloat(6))) + assert(java.lang.Double.isNaN(createdFromNull.getFloat(7))) + + // If we have an UnsafeRow with columns that are initially non-null and we null out those + // columns, then the serialized row representation should be identical to what we would get by + // creating an entirely null row via the converter + val rowWithNoNullColumns: Row = { + val r = new SpecificMutableRow(fieldTypes) + r.setNullAt(0) + r.setBoolean(1, false) + r.setByte(2, 20) + r.setShort(3, 30) + r.setInt(4, 400) + r.setLong(5, 500) + r.setFloat(6, 600) + r.setDouble(7, 700) + r + } + val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8) + converter.writeRow( + rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET) + val setToNullAfterCreation = new UnsafeRow() + setToNullAfterCreation.pointTo( + setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + + setToNullAfterCreation.isNullAt(0) should be (rowWithNoNullColumns.isNullAt(0)) + setToNullAfterCreation.getBoolean(1) should be (rowWithNoNullColumns.getBoolean(1)) + setToNullAfterCreation.getByte(2) should be (rowWithNoNullColumns.getByte(2)) + setToNullAfterCreation.getShort(3) should be (rowWithNoNullColumns.getShort(3)) + setToNullAfterCreation.getInt(4) should be (rowWithNoNullColumns.getInt(4)) + setToNullAfterCreation.getLong(5) should be (rowWithNoNullColumns.getLong(5)) + setToNullAfterCreation.getFloat(6) should be (rowWithNoNullColumns.getFloat(6)) + setToNullAfterCreation.getDouble(7) should be (rowWithNoNullColumns.getDouble(7)) + + for (i <- 0 to fieldTypes.length - 1) { + setToNullAfterCreation.setNullAt(i) + } + assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer)) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 2d16d668fd522..a30052b38fc11 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -27,6 +27,8 @@ class CombiningLimitsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = + Batch("Filter Pushdown", FixedPoint(100), + ColumnPruning) :: Batch("Combine Limit", FixedPoint(10), CombineLimits) :: Batch("Constant Folding", FixedPoint(10), @@ -69,4 +71,21 @@ class CombiningLimitsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("limits: combines two limits after ColumnPruning") { + val originalQuery = + testRelation + .select('a) + .limit(2) + .select('a) + .limit(5) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(2).analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index aa9708b164efa..58d415d9011e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.expressions.{Count, Explode} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -43,6 +43,8 @@ class FilterPushdownSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation1 = LocalRelation('d.int) + // This test already passes. test("eliminate subqueries") { val originalQuery = @@ -90,7 +92,23 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("column pruning for Project(ne, Limit)") { + val originalQuery = + testRelation + .select('a,'b) + .limit(2) + .select('a) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(2).analyze + + comparePlans(optimized, correctAnswer) + } + // After this line is unimplemented. test("simple push down") { val originalQuery = @@ -197,6 +215,23 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("joins: push down left semi join") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = { + x.join(y, LeftSemi, Option("x.a".attr === "y.d".attr && "x.b".attr >= 1 && "y.d".attr >= 2)) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('b >= 1) + val right = testRelation1.where('d >= 2) + val correctAnswer = + left.join(right, LeftSemi, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + test("joins: push down left outer join #1") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index ca6ae482eb2ab..c421006c8fd2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -330,6 +330,17 @@ class DataFrame private[sql]( */ def na: DataFrameNaFunctions = new DataFrameNaFunctions(this) + /** + * Returns a [[DataFrameStatFunctions]] for working statistic functions support. + * {{{ + * // Finding frequent items in column with name 'a'. + * df.stat.freqItems(Seq("a")) + * }}} + * + * @group dfops + */ + def stat: DataFrameStatFunctions = new DataFrameStatFunctions(this) + /** * Cartesian join with another [[DataFrame]]. * @@ -623,17 +634,12 @@ class DataFrame private[sql]( } /** - * (Scala-specific) Compute aggregates by specifying a map from column name to - * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * df.groupBy("department").agg( - * "age" -> "max", - * "expense" -> "sum" - * ) - * }}} + * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups. + * {{ + * // df.agg(...) is a shorthand for df.groupBy().agg(...) + * df.agg("age" -> "max", "salary" -> "avg") + * df.groupBy().agg("age" -> "max", "salary" -> "avg") + * }} * @group dfops */ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { @@ -711,7 +717,7 @@ class DataFrame private[sql]( * @group dfops */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = { - Sample(fraction, withReplacement, seed, logicalPlan) + Sample(0.0, fraction, withReplacement, seed, logicalPlan) } /** @@ -725,6 +731,42 @@ class DataFrame private[sql]( sample(withReplacement, fraction, Utils.random.nextLong) } + /** + * Randomly splits this [[DataFrame]] with the provided weights. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. + * @param seed Seed for sampling. + * @group dfops + */ + def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = { + val sum = weights.sum + val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) + normalizedCumWeights.sliding(2).map { x => + new DataFrame(sqlContext, Sample(x(0), x(1), false, seed, logicalPlan)) + }.toArray + } + + /** + * Randomly splits this [[DataFrame]] with the provided weights. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. + * @group dfops + */ + def randomSplit(weights: Array[Double]): Array[DataFrame] = { + randomSplit(weights, Utils.random.nextLong) + } + + /** + * Randomly splits this [[DataFrame]] with the provided weights. Provided for the Python Api. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. + * @param seed Seed for sampling. + * @group dfops + */ + private[spark] def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = { + randomSplit(weights.toArray, seed) + } + /** * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of @@ -809,15 +851,40 @@ class DataFrame private[sql]( /** * Returns a new [[DataFrame]] with a column renamed. + * This is a no-op if schema doesn't contain existingName. * @group dfops */ def withColumnRenamed(existingName: String, newName: String): DataFrame = { val resolver = sqlContext.analyzer.resolver - val colNames = schema.map { field => - val name = field.name - if (resolver(name, existingName)) Column(name).as(newName) else Column(name) + val shouldRename = schema.exists(f => resolver(f.name, existingName)) + if (shouldRename) { + val colNames = schema.map { field => + val name = field.name + if (resolver(name, existingName)) Column(name).as(newName) else Column(name) + } + select(colNames : _*) + } else { + this + } + } + + /** + * Returns a new [[DataFrame]] with a column dropped. + * This is a no-op if schema doesn't contain column name. + * @group dfops + */ + def drop(colName: String): DataFrame = { + val resolver = sqlContext.analyzer.resolver + val shouldDrop = schema.exists(f => resolver(f.name, colName)) + if (shouldDrop) { + val colsAfterDrop = schema.filter { field => + val name = field.name + !resolver(name, colName) + }.map(f => Column(f.name)) + select(colsAfterDrop : _*) + } else { + this } - select(colNames :_*) } /** @@ -961,9 +1028,7 @@ class DataFrame private[sql]( * @group rdd */ override def repartition(numPartitions: Int): DataFrame = { - sqlContext.createDataFrame( - queryExecution.toRdd.map(_.copy()).repartition(numPartitions), - schema, needsConversion = false) + Repartition(numPartitions, shuffle = true, logicalPlan) } /** @@ -974,10 +1039,7 @@ class DataFrame private[sql]( * @group rdd */ override def coalesce(numPartitions: Int): DataFrame = { - sqlContext.createDataFrame( - queryExecution.toRdd.coalesce(numPartitions), - schema, - needsConversion = false) + Repartition(numPartitions, shuffle = false, logicalPlan) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala new file mode 100644 index 0000000000000..42e5cbc05e1e0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -0,0 +1,68 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.stat.FrequentItems + +/** + * :: Experimental :: + * Statistic functions for [[DataFrame]]s. + */ +@Experimental +final class DataFrameStatFunctions private[sql](df: DataFrame) { + + /** + * Finding frequent items for columns, possibly with false positives. Using the + * frequent element count algorithm described in + * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. + * The `support` should be greater than 1e-4. + * + * @param cols the names of the columns to search frequent items in. + * @param support The minimum frequency for an item to be considered `frequent`. Should be greater + * than 1e-4. + * @return A Local DataFrame with the Array of frequent items for each column. + */ + def freqItems(cols: Array[String], support: Double): DataFrame = { + FrequentItems.singlePassFreqItems(df, cols, support) + } + + /** + * Runs `freqItems` with a default `support` of 1%. + * + * @param cols the names of the columns to search frequent items in. + * @return A Local DataFrame with the Array of frequent items for each column. + */ + def freqItems(cols: Array[String]): DataFrame = { + FrequentItems.singlePassFreqItems(df, cols, 0.01) + } + + /** + * Python friendly implementation for `freqItems` + */ + def freqItems(cols: List[String], support: Double): DataFrame = { + FrequentItems.singlePassFreqItems(df, cols, support) + } + + /** + * Python friendly implementation for `freqItems` with a default `support` of 1%. + */ + def freqItems(cols: List[String]): DataFrame = { + FrequentItems.singlePassFreqItems(df, cols, 0.01) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 4fc5de7e824fe..2fa602a6082dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -30,6 +30,7 @@ private[spark] object SQLConf { val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val CODEGEN_ENABLED = "spark.sql.codegen" + val UNSAFE_ENABLED = "spark.sql.unsafe.enabled" val DIALECT = "spark.sql.dialect" val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" @@ -149,6 +150,14 @@ private[sql] class SQLConf extends Serializable { */ private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean + /** + * When set to true, Spark SQL will use managed memory for certain operations. This option only + * takes effect if codegen is enabled. + * + * Defaults to false as this feature is currently experimental. + */ + private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, "false").toBoolean + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a279b0f07c38a..77f51dfd88d6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConversions._ import scala.collection.immutable import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal import com.google.common.reflect.TypeToken @@ -32,9 +33,11 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.Dialect import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, expressions} import org.apache.spark.sql.execution.{Filter, _} import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} @@ -44,6 +47,45 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.{Partition, SparkContext} +/** + * Currently we support the default dialect named "sql", associated with the class + * [[DefaultDialect]] + * + * And we can also provide custom SQL Dialect, for example in Spark SQL CLI: + * {{{ + *-- switch to "hiveql" dialect + * spark-sql>SET spark.sql.dialect=hiveql; + * spark-sql>SELECT * FROM src LIMIT 1; + * + *-- switch to "sql" dialect + * spark-sql>SET spark.sql.dialect=sql; + * spark-sql>SELECT * FROM src LIMIT 1; + * + *-- register the new SQL dialect + * spark-sql> SET spark.sql.dialect=com.xxx.xxx.SQL99Dialect; + * spark-sql> SELECT * FROM src LIMIT 1; + * + *-- register the non-exist SQL dialect + * spark-sql> SET spark.sql.dialect=NotExistedClass; + * spark-sql> SELECT * FROM src LIMIT 1; + * + *-- Exception will be thrown and switch to dialect + *-- "sql" (for SQLContext) or + *-- "hiveql" (for HiveContext) + * }}} + */ +private[spark] class DefaultDialect extends Dialect { + @transient + protected val sqlParser = { + val catalystSqlParser = new catalyst.SqlParser + new SparkSQLParser(catalystSqlParser.parse) + } + + override def parse(sqlText: String): LogicalPlan = { + sqlParser.parse(sqlText) + } +} + /** * The entry point for working with structured data (rows and columns) in Spark. Allows the * creation of [[DataFrame]] objects as well as the execution of SQL queries. @@ -132,17 +174,27 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer @transient - protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) - - @transient - protected[sql] val sqlParser = { - val fallback = new catalyst.SqlParser - new SparkSQLParser(fallback.parse(_)) + protected[sql] val ddlParser = new DDLParser((sql: String) => { getSQLDialect().parse(sql) }) + + protected[sql] def getSQLDialect(): Dialect = { + try { + val clazz = Utils.classForName(dialectClassName) + clazz.newInstance().asInstanceOf[Dialect] + } catch { + case NonFatal(e) => + // Since we didn't find the available SQL Dialect, it will fail even for SET command: + // SET spark.sql.dialect=sql; Let's reset as default dialect automatically. + val dialect = conf.dialect + // reset the sql dialect + conf.unsetConf(SQLConf.DIALECT) + // throw out the exception, and the default sql dialect will take effect for next query. + throw new DialectException( + s"""Instantiating dialect '$dialect' failed. + |Reverting to default dialect '${conf.dialect}'""".stripMargin, e) + } } - protected[sql] def parseSql(sql: String): LogicalPlan = { - ddlParser.parse(sql, false).getOrElse(sqlParser.parse(sql)) - } + protected[sql] def parseSql(sql: String): LogicalPlan = ddlParser.parse(sql, false) protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) @@ -156,6 +208,12 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] val defaultSession = createSession() + protected[sql] def dialectClassName = if (conf.dialect == "sql") { + classOf[DefaultDialect].getCanonicalName + } else { + conf.dialect + } + sparkContext.getConf.getAll.foreach { case (key, value) if key.startsWith("spark.sql") => setConf(key, value) case _ => @@ -945,11 +1003,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group basic */ def sql(sqlText: String): DataFrame = { - if (conf.dialect == "sql") { - DataFrame(this, parseSql(sqlText)) - } else { - sys.error(s"Unsupported SQL dialect: ${conf.dialect}") - } + DataFrame(this, parseSql(sqlText)) } /** @@ -1011,6 +1065,8 @@ class SQLContext(@transient val sparkContext: SparkContext) def codegenEnabled: Boolean = self.conf.codegenEnabled + def unsafeEnabled: Boolean = self.conf.unsafeEnabled + def numPartitions: Int = self.conf.numShufflePartitions def strategies: Seq[Strategy] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 1fd387eec7e57..57effbf7ec501 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -84,7 +84,7 @@ object RDDConversions { } /** Logical plan node for scanning data from an RDD. */ -case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext) +private[sql] case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext) extends LogicalPlan with MultiInstanceRelation { override def children: Seq[LogicalPlan] = Nil @@ -105,11 +105,12 @@ case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLCont } /** Physical plan node for scanning data from an RDD. */ -case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { +private[sql] case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { override def execute(): RDD[Row] = rdd } /** Logical plan node for scanning data from a local collection. */ +private[sql] case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[Row])(sqlContext: SQLContext) extends LogicalPlan with MultiInstanceRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index b1ef6556de1e9..5d9f202681045 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.trees._ @@ -40,6 +41,7 @@ case class AggregateEvaluation( * ensure all values where `groupingExpressions` are equal are present. * @param groupingExpressions expressions that are evaluated to determine grouping. * @param aggregateExpressions expressions that are computed for each group. + * @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used. * @param child the input data source. */ @DeveloperApi @@ -47,6 +49,7 @@ case class GeneratedAggregate( partial: Boolean, groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], + unsafeEnabled: Boolean, child: SparkPlan) extends UnaryNode { @@ -225,6 +228,21 @@ case class GeneratedAggregate( case e: Expression if groupMap.contains(e) => groupMap(e) }) + val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema) + + val groupKeySchema: StructType = { + val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) => + // This is a dummy field name + StructField(idx.toString, expr.dataType, expr.nullable) + } + StructType(fields) + } + + val schemaSupportsUnsafe: Boolean = { + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && + UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema) + } + child.execute().mapPartitions { iter => // Builds a new custom class for holding the results of aggregation for a group. val initialValues = computeFunctions.flatMap(_.initialValues) @@ -265,7 +283,49 @@ case class GeneratedAggregate( val resultProjection = resultProjectionBuilder() Iterator(resultProjection(buffer)) + } else if (unsafeEnabled && schemaSupportsUnsafe) { + log.info("Using Unsafe-based aggregator") + val aggregationMap = new UnsafeFixedWidthAggregationMap( + newAggregationBuffer(EmptyRow), + aggregationBufferSchema, + groupKeySchema, + TaskContext.get.taskMemoryManager(), + 1024 * 16, // initial capacity + false // disable tracking of performance metrics + ) + + while (iter.hasNext) { + val currentRow: Row = iter.next() + val groupKey: Row = groupProjection(currentRow) + val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey) + updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow)) + } + + new Iterator[Row] { + private[this] val mapIterator = aggregationMap.iterator() + private[this] val resultProjection = resultProjectionBuilder() + + def hasNext: Boolean = mapIterator.hasNext + + def next(): Row = { + val entry = mapIterator.next() + val result = resultProjection(joinedRow(entry.key, entry.value)) + if (hasNext) { + result + } else { + // This is the last element in the iterator, so let's free the buffer. Before we do, + // though, we need to make a defensive copy of the result so that we don't return an + // object that might contain dangling pointers to the freed memory + val resultCopy = result.copy() + aggregationMap.free() + resultCopy + } + } + } } else { + if (unsafeEnabled) { + log.info("Not using Unsafe-based aggregator because it is not supported for this schema") + } val buffers = new java.util.HashMap[Row, MutableRow]() var currentRow: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index 8a8c3a404323a..ace9af5f384c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute /** * Physical plan node for scanning data from a local collection. */ -case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNode { +private[sql] case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNode { private lazy val rdd = sqlContext.sparkContext.parallelize(rows) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 030ef118f75d4..326e8ce4ca524 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -136,10 +136,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { partial = false, namedGroupingAttributes, rewrittenAggregateExpressions, + unsafeEnabled, execution.GeneratedAggregate( partial = true, groupingExpressions, partialComputation, + unsafeEnabled, planLater(child))) :: Nil // Cases where some aggregate can not be codegened @@ -283,7 +285,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Distinct(child) => execution.Distinct(partial = false, execution.Distinct(partial = true, planLater(child))) :: Nil - + case logical.Repartition(numPartitions, shuffle, child) => + execution.Repartition(numPartitions, shuffle, planLater(child)) :: Nil case logical.SortPartitions(sortExprs, child) => // This sort only sorts tuples within a partition. Its requiredDistribution will be // an UnspecifiedDistribution. @@ -300,8 +303,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Expand(projections, output, planLater(child)) :: Nil case logical.Aggregate(group, agg, child) => execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil - case logical.Sample(fraction, withReplacement, seed, child) => - execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil + case logical.Sample(lb, ub, withReplacement, seed, child) => + execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => LocalTableScan(output, data) :: Nil case logical.Limit(IntegerLiteral(limit), child) => @@ -317,7 +320,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil - case logical.Repartition(expressions, child) => + case logical.RepartitionByExpression(expressions, child) => execution.Exchange( HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index d286fe81bee5f..5ca11e67a9434 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -63,16 +63,32 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { /** * :: DeveloperApi :: + * Sample the dataset. + * @param lowerBound Lower-bound of the sampling probability (usually 0.0) + * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled + * will be ub - lb. + * @param withReplacement Whether to sample with replacement. + * @param seed the random seed + * @param child the QueryPlan */ @DeveloperApi -case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan) +case class Sample( + lowerBound: Double, + upperBound: Double, + withReplacement: Boolean, + seed: Long, + child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output // TODO: How to pick seed? override def execute(): RDD[Row] = { - child.execute().map(_.copy()).sample(withReplacement, fraction, seed) + if (withReplacement) { + child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed) + } else { + child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed) + } } } @@ -245,6 +261,20 @@ case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode { } } +/** + * :: DeveloperApi :: + * Return a new RDD that has exactly `numPartitions` partitions. + */ +@DeveloperApi +case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan) + extends UnaryNode { + override def output: Seq[Attribute] = child.output + + override def execute(): RDD[Row] = { + child.execute().map(_.copy()).coalesce(numPartitions, shuffle) + } +} + /** * :: DeveloperApi :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 99f24910fd61f..98df5bef34efa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -42,7 +42,7 @@ trait RunnableCommand extends logical.Command { * A physical operator that executes the run method of a `RunnableCommand` and * saves the result to prevent multiple executions. */ -case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { +private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { /** * A concrete command should override this lazy field to wrap up any side effects caused by the * command or any other computation that should be evaluated exactly once. The value of this field diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala new file mode 100644 index 0000000000000..9ac732b55b188 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.expressions + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.expressions.{Row, LeafExpression} +import org.apache.spark.sql.types.{LongType, DataType} + +/** + * Returns monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the lower 33 bits + * represent the record number within each partition. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * Since this expression is stateful, it cannot be a case object. + */ +private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { + + /** + * Record ID within each partition. By being transient, count's value is reset to 0 every time + * we serialize and deserialize it. + */ + @transient private[this] var count: Long = 0L + + override type EvaluatedType = Long + + override def nullable: Boolean = false + + override def dataType: DataType = LongType + + override def eval(input: Row): Long = { + val currentCount = count + count += 1 + (TaskContext.get().partitionId().toLong << 33) + currentCount + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index fe7607c6ac340..c2c6cbd491598 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -18,16 +18,14 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext -import org.apache.spark.sql.catalyst.expressions.{Row, Expression} -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Row} import org.apache.spark.sql.types.{IntegerType, DataType} /** * Expression that returns the current partition id of the Spark task. */ -case object SparkPartitionID extends Expression with trees.LeafNode[Expression] { - self: Product => +private[sql] case object SparkPartitionID extends LeafExpression { override type EvaluatedType = Int diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 56200f6b8c8a9..6aaf35fb429e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -59,10 +59,7 @@ case class BroadcastNestedLoopJoin( } @transient private lazy val boundCondition = - InterpretedPredicate.create( - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true))) + newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) override def execute(): RDD[Row] = { val broadcastedRelation = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index e06f63f94b78b..b03af410dca08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -45,10 +45,7 @@ case class LeftSemiJoinBNL( override def right: SparkPlan = broadcast @transient private lazy val boundCondition = - InterpretedPredicate.create( - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true))) + newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) override def execute(): RDD[Row] = { val broadcastedRelation = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala new file mode 100644 index 0000000000000..5ae7e107544f8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -0,0 +1,121 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.stat + +import scala.collection.mutable.{Map => MutableMap} + +import org.apache.spark.Logging +import org.apache.spark.sql.{Column, DataFrame, Row} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types.{ArrayType, StructField, StructType} + +private[sql] object FrequentItems extends Logging { + + /** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */ + private class FreqItemCounter(size: Int) extends Serializable { + val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long] + + /** + * Add a new example to the counts if it exists, otherwise deduct the count + * from existing items. + */ + def add(key: Any, count: Long): this.type = { + if (baseMap.contains(key)) { + baseMap(key) += count + } else { + if (baseMap.size < size) { + baseMap += key -> count + } else { + // TODO: Make this more efficient... A flatMap? + baseMap.retain((k, v) => v > count) + baseMap.transform((k, v) => v - count) + } + } + this + } + + /** + * Merge two maps of counts. + * @param other The map containing the counts for that partition + */ + def merge(other: FreqItemCounter): this.type = { + other.baseMap.foreach { case (k, v) => + add(k, v) + } + this + } + } + + /** + * Finding frequent items for columns, possibly with false positives. Using the + * frequent element count algorithm described in + * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. + * The `support` should be greater than 1e-4. + * For Internal use only. + * + * @param df The input DataFrame + * @param cols the names of the columns to search frequent items in + * @param support The minimum frequency for an item to be considered `frequent`. Should be greater + * than 1e-4. + * @return A Local DataFrame with the Array of frequent items for each column. + */ + private[sql] def singlePassFreqItems( + df: DataFrame, + cols: Seq[String], + support: Double): DataFrame = { + require(support >= 1e-4, s"support ($support) must be greater than 1e-4.") + val numCols = cols.length + // number of max items to keep counts for + val sizeOfMap = (1 / support).toInt + val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap)) + val originalSchema = df.schema + val colInfo = cols.map { name => + val index = originalSchema.fieldIndex(name) + (name, originalSchema.fields(index).dataType) + } + + val freqItems = df.select(cols.map(Column(_)):_*).rdd.aggregate(countMaps)( + seqOp = (counts, row) => { + var i = 0 + while (i < numCols) { + val thisMap = counts(i) + val key = row.get(i) + thisMap.add(key, 1L) + i += 1 + } + counts + }, + combOp = (baseCounts, counts) => { + var i = 0 + while (i < numCols) { + baseCounts(i).merge(counts(i)) + i += 1 + } + baseCounts + } + ) + val justItems = freqItems.map(m => m.baseMap.keys.toSeq) + val resultRow = Row(justItems:_*) + // append frequent Items to the column name for easy debugging + val outputCols = colInfo.map { v => + StructField(v._1 + "_freqItems", ArrayType(v._2, false)) + } + val schema = StructType(outputCols).toAttributes + new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9738fd4f93bad..aa31d04a0cbe4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -301,6 +301,22 @@ object functions { */ def lower(e: Column): Column = Lower(e.expr) + /** + * A column expression that generates monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the record number + * within each partition in the lower 33 bits. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + * This expression would return the following IDs: + * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + * + * @group normal_funcs + */ + def monotonicallyIncreasingId(): Column = execution.expressions.MonotonicallyIncreasingID() + /** * Unary minus, i.e. negate the expression. * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index f3b5455574d1a..2f6ba48dbc3d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -37,7 +37,7 @@ private[sql] object JDBCRDD extends Logging { * @param sqlType - A field of java.sql.Types * @return The Catalyst type corresponding to sqlType. */ - private def getCatalystType(sqlType: Int): DataType = { + private def getCatalystType(sqlType: Int, precision: Int, scale: Int): DataType = { val answer = sqlType match { case java.sql.Types.ARRAY => null case java.sql.Types.BIGINT => LongType @@ -49,6 +49,8 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.CLOB => StringType case java.sql.Types.DATALINK => null case java.sql.Types.DATE => DateType + case java.sql.Types.DECIMAL + if precision != 0 || scale != 0 => DecimalType(precision, scale) case java.sql.Types.DECIMAL => DecimalType.Unlimited case java.sql.Types.DISTINCT => null case java.sql.Types.DOUBLE => DoubleType @@ -61,6 +63,8 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.NCHAR => StringType case java.sql.Types.NCLOB => StringType case java.sql.Types.NULL => null + case java.sql.Types.NUMERIC + if precision != 0 || scale != 0 => DecimalType(precision, scale) case java.sql.Types.NUMERIC => DecimalType.Unlimited case java.sql.Types.NVARCHAR => StringType case java.sql.Types.OTHER => null @@ -109,10 +113,11 @@ private[sql] object JDBCRDD extends Logging { val dataType = rsmd.getColumnType(i + 1) val typeName = rsmd.getColumnTypeName(i + 1) val fieldSize = rsmd.getPrecision(i + 1) + val fieldScale = rsmd.getScale(i + 1) val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls val metadata = new MetadataBuilder().putString("name", columnName) var columnType = quirks.getCatalystType(dataType, typeName, fieldSize, metadata) - if (columnType == null) columnType = getCatalystType(dataType) + if (columnType == null) columnType = getCatalystType(dataType, fieldSize, fieldScale) fields(i) = StructField(columnName, columnType, nullable, metadata.build()) i = i + 1 } @@ -154,7 +159,7 @@ private[sql] object JDBCRDD extends Logging { def getConnector(driver: String, url: String, properties: Properties): () => Connection = { () => { try { - if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver) + if (driver != null) DriverRegistry.register(driver) } catch { case e: ClassNotFoundException => { logWarning(s"Couldn't find class $driver", e); @@ -307,6 +312,7 @@ private[sql] class JDBCRDD( case BooleanType => BooleanConversion case DateType => DateConversion case DecimalType.Unlimited => DecimalConversion + case DecimalType.Fixed(d) => DecimalConversion case DoubleType => DoubleConversion case FloatType => FloatConversion case IntegerType => IntegerConversion diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 5f480083d5a49..d6b3fb3291a2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -100,7 +100,7 @@ private[sql] class DefaultSource extends RelationProvider { val upperBound = parameters.getOrElse("upperBound", null) val numPartitions = parameters.getOrElse("numPartitions", null) - if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver) + if (driver != null) DriverRegistry.register(driver) if (partitionColumn != null && (lowerBound == null || upperBound == null || numPartitions == null)) { @@ -136,7 +136,7 @@ private[sql] case class JDBCRelation( override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { - val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName + val driver: String = DriverRegistry.getDriverClassName(url) JDBCRDD.scanTable( sqlContext.sparkContext, schema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index d4e0abc040bc6..ae9af1eabe68e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -17,10 +17,14 @@ package org.apache.spark.sql -import java.sql.{Connection, DriverManager, PreparedStatement} +import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement} +import java.util.Properties + +import scala.collection.mutable import org.apache.spark.Logging import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils package object jdbc { private[sql] object JDBCWriteDetails extends Logging { @@ -179,4 +183,58 @@ package object jdbc { } } + + private [sql] class DriverWrapper(val wrapped: Driver) extends Driver { + override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url) + + override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant() + + override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = { + wrapped.getPropertyInfo(url, info) + } + + override def getMinorVersion: Int = wrapped.getMinorVersion + + override def getParentLogger: java.util.logging.Logger = wrapped.getParentLogger + + override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info) + + override def getMajorVersion: Int = wrapped.getMajorVersion + } + + /** + * java.sql.DriverManager is always loaded by bootstrap classloader, + * so it can't load JDBC drivers accessible by Spark ClassLoader. + * + * To solve the problem, drivers from user-supplied jars are wrapped + * into thin wrapper. + */ + private [sql] object DriverRegistry extends Logging { + + private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty + + def register(className: String): Unit = { + val cls = Utils.getContextOrSparkClassLoader.loadClass(className) + if (cls.getClassLoader == null) { + logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required") + } else if (wrapperMap.get(className).isDefined) { + logTrace(s"Wrapper for $className already exists") + } else { + synchronized { + if (wrapperMap.get(className).isEmpty) { + val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) + DriverManager.registerDriver(wrapper) + wrapperMap(className) = wrapper + logTrace(s"Wrapper for $className registered") + } + } + } + } + + def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { + case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName + case driver => driver.getClass.getCanonicalName + } + } + } // package object jdbc diff --git a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala new file mode 100644 index 0000000000000..d901542b7eaf3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala @@ -0,0 +1,385 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.mathfuncs._ +import org.apache.spark.sql.functions.lit + +/** + * :: Experimental :: + * Mathematical Functions available for [[DataFrame]]. + * + * @groupname double_funcs Functions that require DoubleType as an input + */ +@Experimental +// scalastyle:off +object mathfunctions { +// scalastyle:on + + private[this] implicit def toColumn(expr: Expression): Column = Column(expr) + + /** + * Computes the cosine inverse of the given value; the returned angle is in the range + * 0.0 through pi. + */ + def acos(e: Column): Column = Acos(e.expr) + + /** + * Computes the cosine inverse of the given column; the returned angle is in the range + * 0.0 through pi. + */ + def acos(columnName: String): Column = acos(Column(columnName)) + + /** + * Computes the sine inverse of the given value; the returned angle is in the range + * -pi/2 through pi/2. + */ + def asin(e: Column): Column = Asin(e.expr) + + /** + * Computes the sine inverse of the given column; the returned angle is in the range + * -pi/2 through pi/2. + */ + def asin(columnName: String): Column = asin(Column(columnName)) + + /** + * Computes the tangent inverse of the given value. + */ + def atan(e: Column): Column = Atan(e.expr) + + /** + * Computes the tangent inverse of the given column. + */ + def atan(columnName: String): Column = atan(Column(columnName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + */ + def atan2(l: Column, r: Column): Column = Atan2(l.expr, r.expr) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + */ + def atan2(l: Column, rightName: String): Column = atan2(l, Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + */ + def atan2(leftName: String, r: Column): Column = atan2(Column(leftName), r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + */ + def atan2(leftName: String, rightName: String): Column = + atan2(Column(leftName), Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + */ + def atan2(l: Column, r: Double): Column = atan2(l, lit(r).expr) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta).= + */ + def atan2(leftName: String, r: Double): Column = atan2(Column(leftName), r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + */ + def atan2(l: Double, r: Column): Column = atan2(lit(l).expr, r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + */ + def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) + + /** + * Computes the cube-root of the given value. + */ + def cbrt(e: Column): Column = Cbrt(e.expr) + + /** + * Computes the cube-root of the given column. + */ + def cbrt(columnName: String): Column = cbrt(Column(columnName)) + + /** + * Computes the ceiling of the given value. + */ + def ceil(e: Column): Column = Ceil(e.expr) + + /** + * Computes the ceiling of the given column. + */ + def ceil(columnName: String): Column = ceil(Column(columnName)) + + /** + * Computes the cosine of the given value. + */ + def cos(e: Column): Column = Cos(e.expr) + + /** + * Computes the cosine of the given column. + */ + def cos(columnName: String): Column = cos(Column(columnName)) + + /** + * Computes the hyperbolic cosine of the given value. + */ + def cosh(e: Column): Column = Cosh(e.expr) + + /** + * Computes the hyperbolic cosine of the given column. + */ + def cosh(columnName: String): Column = cosh(Column(columnName)) + + /** + * Computes the exponential of the given value. + */ + def exp(e: Column): Column = Exp(e.expr) + + /** + * Computes the exponential of the given column. + */ + def exp(columnName: String): Column = exp(Column(columnName)) + + /** + * Computes the exponential of the given value minus one. + */ + def expm1(e: Column): Column = Expm1(e.expr) + + /** + * Computes the exponential of the given column. + */ + def expm1(columnName: String): Column = expm1(Column(columnName)) + + /** + * Computes the floor of the given value. + */ + def floor(e: Column): Column = Floor(e.expr) + + /** + * Computes the floor of the given column. + */ + def floor(columnName: String): Column = floor(Column(columnName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + */ + def hypot(l: Column, r: Column): Column = Hypot(l.expr, r.expr) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + */ + def hypot(l: Column, rightName: String): Column = hypot(l, Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + */ + def hypot(leftName: String, r: Column): Column = hypot(Column(leftName), r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + */ + def hypot(leftName: String, rightName: String): Column = + hypot(Column(leftName), Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + */ + def hypot(l: Column, r: Double): Column = hypot(l, lit(r).expr) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + */ + def hypot(leftName: String, r: Double): Column = hypot(Column(leftName), r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + */ + def hypot(l: Double, r: Column): Column = hypot(lit(l).expr, r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + */ + def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName)) + + /** + * Computes the natural logarithm of the given value. + */ + def log(e: Column): Column = Log(e.expr) + + /** + * Computes the natural logarithm of the given column. + */ + def log(columnName: String): Column = log(Column(columnName)) + + /** + * Computes the logarithm of the given value in Base 10. + */ + def log10(e: Column): Column = Log10(e.expr) + + /** + * Computes the logarithm of the given value in Base 10. + */ + def log10(columnName: String): Column = log10(Column(columnName)) + + /** + * Computes the natural logarithm of the given value plus one. + */ + def log1p(e: Column): Column = Log1p(e.expr) + + /** + * Computes the natural logarithm of the given column plus one. + */ + def log1p(columnName: String): Column = log1p(Column(columnName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + */ + def pow(l: Column, r: Column): Column = Pow(l.expr, r.expr) + + /** + * Returns the value of the first argument raised to the power of the second argument. + */ + def pow(l: Column, rightName: String): Column = pow(l, Column(rightName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + */ + def pow(leftName: String, r: Column): Column = pow(Column(leftName), r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + */ + def pow(leftName: String, rightName: String): Column = pow(Column(leftName), Column(rightName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + */ + def pow(l: Column, r: Double): Column = pow(l, lit(r).expr) + + /** + * Returns the value of the first argument raised to the power of the second argument. + */ + def pow(leftName: String, r: Double): Column = pow(Column(leftName), r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + */ + def pow(l: Double, r: Column): Column = pow(lit(l).expr, r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + */ + def pow(l: Double, rightName: String): Column = pow(l, Column(rightName)) + + /** + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. + */ + def rint(e: Column): Column = Rint(e.expr) + + /** + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. + */ + def rint(columnName: String): Column = rint(Column(columnName)) + + /** + * Computes the signum of the given value. + */ + def signum(e: Column): Column = Signum(e.expr) + + /** + * Computes the signum of the given column. + */ + def signum(columnName: String): Column = signum(Column(columnName)) + + /** + * Computes the sine of the given value. + */ + def sin(e: Column): Column = Sin(e.expr) + + /** + * Computes the sine of the given column. + */ + def sin(columnName: String): Column = sin(Column(columnName)) + + /** + * Computes the hyperbolic sine of the given value. + */ + def sinh(e: Column): Column = Sinh(e.expr) + + /** + * Computes the hyperbolic sine of the given column. + */ + def sinh(columnName: String): Column = sinh(Column(columnName)) + + /** + * Computes the tangent of the given value. + */ + def tan(e: Column): Column = Tan(e.expr) + + /** + * Computes the tangent of the given column. + */ + def tan(columnName: String): Column = tan(Column(columnName)) + + /** + * Computes the hyperbolic tangent of the given value. + */ + def tanh(e: Column): Column = Tanh(e.expr) + + /** + * Computes the hyperbolic tangent of the given column. + */ + def tanh(columnName: String): Column = tanh(Column(columnName)) + + /** + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + */ + def toDeg(e: Column): Column = ToDegrees(e.expr) + + /** + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + */ + def toDeg(columnName: String): Column = toDeg(Column(columnName)) + + /** + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + */ + def toRad(e: Column): Column = ToRadians(e.expr) + + /** + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + */ + def toRad(columnName: String): Column = toRad(Column(columnName)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala new file mode 100644 index 0000000000000..f5ce2718bec4a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter + +import parquet.Log +import parquet.hadoop.util.ContextUtil +import parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} + +private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + val LOG = Log.getLog(classOf[ParquetOutputCommitter]) + + override def getWorkPath(): Path = outputPath + override def abortTask(taskContext: TaskAttemptContext): Unit = {} + override def commitTask(taskContext: TaskAttemptContext): Unit = {} + override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true + override def setupJob(jobContext: JobContext): Unit = {} + override def setupTask(taskContext: TaskAttemptContext): Unit = {} + + override def commitJob(jobContext: JobContext) { + val configuration = ContextUtil.getConfiguration(jobContext) + val fileSystem = outputPath.getFileSystem(configuration) + + if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) { + try { + val outputStatus = fileSystem.getFileStatus(outputPath) + val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus) + try { + ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers) + } catch { + case e: Exception => { + LOG.warn("could not write summary file for " + outputPath, e) + val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fileSystem.exists(metadataPath)) { + fileSystem.delete(metadataPath, true) + } + } + } + } catch { + case e: Exception => LOG.warn("could not write summary file for " + outputPath, e) + } + } + + if (configuration.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)) { + try { + val successPath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) + fileSystem.create(successPath).close() + } catch { + case e: Exception => LOG.warn("could not write success file for " + outputPath, e) + } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index a938b77578686..aded126ea0615 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -381,6 +381,7 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) extends parquet.hadoop.ParquetOutputFormat[Row] { // override to accept existing directories as valid output directory override def checkOutputSpecs(job: JobContext): Unit = {} + var committer: OutputCommitter = null // override to choose output filename so not overwrite existing ones override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { @@ -403,6 +404,26 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] } + + // override to create output committer from configuration + override def getOutputCommitter(context: TaskAttemptContext): OutputCommitter = { + if (committer == null) { + val output = getOutputPath(context) + val cls = context.getConfiguration.getClass("spark.sql.parquet.output.committer.class", + classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) + val ctor = cls.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + committer = ctor.newInstance(output, context).asInstanceOf[ParquetOutputCommitter] + } + committer + } + + // FileOutputFormat.getOutputPath takes JobConf in hadoop-1 but JobContext in hadoop-2 + private def getOutputPath(context: TaskAttemptContext): Path = { + context.getConfiguration().get("mapred.output.dir") match { + case null => null + case name => new Path(name) + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index e7a0685e013d8..1abf3aa51cb25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -38,12 +38,12 @@ private[sql] class DDLParser( parseQuery: String => LogicalPlan) extends AbstractSparkSQLParser with DataTypeParser with Logging { - def parse(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = { + def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { try { - Some(parse(input)) + parse(input) } catch { case ddlException: DDLException => throw ddlException - case _ if !exceptionOnError => None + case _ if !exceptionOnError => parseQuery(input) case x: Throwable => throw x } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index e02c84872c628..966d879e1fc9f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -22,10 +22,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.TestData$; +import org.apache.spark.sql.*; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.*; @@ -41,6 +38,7 @@ import java.util.Map; import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.mathfunctions.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -98,6 +96,14 @@ public void testVarargMethods() { df.groupBy().agg(countDistinct("key", "value")); df.groupBy().agg(countDistinct(col("key"), col("value"))); df.select(coalesce(col("key"))); + + // Varargs with mathfunctions + DataFrame df2 = context.table("testData2"); + df2.select(exp("a"), exp("b")); + df2.select(exp(log("a"))); + df2.select(pow("a", "a"), pow("b", 2.0)); + df2.select(pow(col("a"), col("b")), exp("b")); + df2.select(sin("a"), acos("b")); } @Ignore @@ -169,5 +175,12 @@ public void testCreateDataFrameFromJavaBeans() { Assert.assertEquals(bean.getD().get(i), d.apply(i)); } } - + + @Test + public void testFrequentItems() { + DataFrame df = context.table("testData2"); + String[] cols = new String[]{"a"}; + DataFrame results = df.stat().freqItems(cols, 0.2); + Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 904073b8cb2aa..2ba5fc21ff57c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ - class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ @@ -310,6 +309,17 @@ class ColumnExpressionSuite extends QueryTest { ) } + test("monotonicallyIncreasingId") { + // Make sure we have 2 partitions, each with 2 records. + val df = TestSQLContext.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => + Iterator(Tuple1(1), Tuple1(2)) + }.toDF("a") + checkAnswer( + df.select(monotonicallyIncreasingId()), + Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil + ) + } + test("sparkPartitionId") { val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala new file mode 100644 index 0000000000000..bb1d29c71d23b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.FunSuite +import org.scalatest.Matchers._ + +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ + +class DataFrameStatSuite extends FunSuite { + + val sqlCtx = TestSQLContext + + test("Frequent Items") { + def toLetter(i: Int): String = (i + 96).toChar.toString + val rows = Array.tabulate(1000) { i => + if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0) + } + val df = sqlCtx.sparkContext.parallelize(rows).toDF("numbers", "letters", "negDoubles") + + val results = df.stat.freqItems(Array("numbers", "letters"), 0.1) + val items = results.collect().head + items.getSeq[Int](0) should contain (1) + items.getSeq[String](1) should contain (toLetter(1)) + + val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1) + val items2 = singleColResults.collect().head + items2.getSeq[Double](0) should contain (-1.0) + + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5ec06d448e50f..e286fef23caa4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -499,6 +499,22 @@ class DataFrameSuite extends QueryTest { Row(2) :: Row(3) :: Row(4) :: Nil) } + test("drop column using drop") { + val df = testData.drop("key") + checkAnswer( + df, + testData.collect().map(x => Row(x.getString(1))).toSeq) + assert(df.schema.map(_.name) === Seq("value")) + } + + test("drop unknown column (no-op)") { + val df = testData.drop("random") + checkAnswer( + df, + testData.collect().toSeq) + assert(df.schema.map(_.name) === Seq("key","value")) + } + test("withColumnRenamed") { val df = testData.toDF().withColumn("newCol", col("key") + 1) .withColumnRenamed("value", "valueRenamed") @@ -510,6 +526,23 @@ class DataFrameSuite extends QueryTest { assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol")) } + test("randomSplit") { + val n = 600 + val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id") + for (seed <- 1 to 5) { + val splits = data.randomSplit(Array[Double](1, 2, 3), seed) + assert(splits.length == 3, "wrong number of splits") + + assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == + data.collect().toList, "incomplete or wrong split") + + val s = splits.map(_.count()) + assert(math.abs(s(0) - 100) < 50) // std = 9.13 + assert(math.abs(s(1) - 200) < 50) // std = 11.55 + assert(math.abs(s(2) - 300) < 50) // std = 12.25 + } + } + test("describe") { val describeTestData = Seq( ("Bob", 16, 176), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala new file mode 100644 index 0000000000000..9e19bb7482e9b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.lang.{Double => JavaDouble} + +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.mathfunctions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ + +private[this] object MathExpressionsTestData { + + case class DoubleData(a: JavaDouble, b: JavaDouble) + val doubleData = TestSQLContext.sparkContext.parallelize( + (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1))).toDF() + + val nnDoubleData = TestSQLContext.sparkContext.parallelize( + (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1))).toDF() + + case class NullDoubles(a: JavaDouble) + val nullDoubles = + TestSQLContext.sparkContext.parallelize( + NullDoubles(1.0) :: + NullDoubles(2.0) :: + NullDoubles(3.0) :: + NullDoubles(null) :: Nil + ).toDF() +} + +class MathExpressionsSuite extends QueryTest { + + import MathExpressionsTestData._ + + def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + c: Column => Column, + f: T => T): Unit = { + checkAnswer( + doubleData.select(c('a)), + (1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c('b)), + (1 to 10).map(n => Row(f((-n * 0.2 + 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c(lit(null))), + (1 to 10).map(_ => Row(null)) + ) + } + + def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { + checkAnswer( + nnDoubleData.select(c('a)), + (1 to 10).map(n => Row(f(n * 0.1))) + ) + + if (f(-1) === math.log1p(-1)) { + checkAnswer( + nnDoubleData.select(c('b)), + (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(Double.NegativeInfinity) + ) + } else { + checkAnswer( + nnDoubleData.select(c('b)), + (1 to 10).map(n => Row(null)) + ) + } + + checkAnswer( + nnDoubleData.select(c(lit(null))), + (1 to 10).map(_ => Row(null)) + ) + } + + def testTwoToOneMathFunction( + c: (Column, Column) => Column, + d: (Column, Double) => Column, + f: (Double, Double) => Double): Unit = { + checkAnswer( + nnDoubleData.select(c('a, 'a)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) + ) + + checkAnswer( + nnDoubleData.select(c('a, 'b)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1)))) + ) + + checkAnswer( + nnDoubleData.select(d('a, 2.0)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), 2.0))) + ) + + checkAnswer( + nnDoubleData.select(d('a, -0.5)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), -0.5))) + ) + + val nonNull = nullDoubles.collect().toSeq.filter(r => r.get(0) != null) + + checkAnswer( + nullDoubles.select(c('a, 'a)).orderBy('a.asc), + Row(null) +: nonNull.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) + ) + } + + test("sin") { + testOneToOneMathFunction(sin, math.sin) + } + + test("asin") { + testOneToOneMathFunction(asin, math.asin) + } + + test("sinh") { + testOneToOneMathFunction(sinh, math.sinh) + } + + test("cos") { + testOneToOneMathFunction(cos, math.cos) + } + + test("acos") { + testOneToOneMathFunction(acos, math.acos) + } + + test("cosh") { + testOneToOneMathFunction(cosh, math.cosh) + } + + test("tan") { + testOneToOneMathFunction(tan, math.tan) + } + + test("atan") { + testOneToOneMathFunction(atan, math.atan) + } + + test("tanh") { + testOneToOneMathFunction(tanh, math.tanh) + } + + test("toDeg") { + testOneToOneMathFunction(toDeg, math.toDegrees) + } + + test("toRad") { + testOneToOneMathFunction(toRad, math.toRadians) + } + + test("cbrt") { + testOneToOneMathFunction(cbrt, math.cbrt) + } + + test("ceil") { + testOneToOneMathFunction(ceil, math.ceil) + } + + test("floor") { + testOneToOneMathFunction(floor, math.floor) + } + + test("rint") { + testOneToOneMathFunction(rint, math.rint) + } + + test("exp") { + testOneToOneMathFunction(exp, math.exp) + } + + test("expm1") { + testOneToOneMathFunction(expm1, math.expm1) + } + + test("signum") { + testOneToOneMathFunction[Double](signum, math.signum) + } + + test("pow") { + testTwoToOneMathFunction(pow, pow, math.pow) + } + + test("hypot") { + testTwoToOneMathFunction(hypot, hypot, math.hypot) + } + + test("atan2") { + testTwoToOneMathFunction(atan2, atan2, math.atan2) + } + + test("log") { + testOneToOneNonNegativeMathFunction(log, math.log) + } + + test("log10") { + testOneToOneNonNegativeMathFunction(log10, math.log10) + } + + test("log1p") { + testOneToOneNonNegativeMathFunction(log1p, math.log1p) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 59f9508444f25..bbf9ab113ca43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -132,11 +132,7 @@ object QueryTest { val errorMessage = s""" |Results do not match for query: - |${df.logicalPlan} - |== Analyzed Plan == - |${df.queryExecution.analyzed} - |== Physical Plan == - |${df.queryExecution.executedPlan} + |${df.queryExecution} |== Results == |${sideBySide( s"== Correct Answer - ${expectedAnswer.size} ==" +: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 9e02e69fda3f2..0ab8558c1db13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,13 +19,18 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} + import org.apache.spark.sql.types._ +/** A SQL Dialect for testing purpose, and it can not be nested type */ +class MyDialect extends DefaultDialect + class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // Make sure the tables are loaded. TestData @@ -46,6 +51,16 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) } + test("support table.star") { + checkAnswer( + sql( + """ + |SELECT r.* + |FROM testData l join testData2 r on (l.key = r.a) + """.stripMargin), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + } + test("self join with alias in agg") { Seq(1,2,3) .map(i => (i, i.toString)) @@ -64,6 +79,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) } + test("SQL Dialect Switching to a new SQL parser") { + val newContext = new SQLContext(TestSQLContext.sparkContext) + newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) + assert(newContext.getSQLDialect().getClass === classOf[MyDialect]) + assert(newContext.sql("SELECT 1").collect() === Array(Row(1))) + } + + test("SQL Dialect Switch to an invalid parser with alias") { + val newContext = new SQLContext(TestSQLContext.sparkContext) + newContext.sql("SET spark.sql.dialect=MyTestClass") + intercept[DialectException] { + newContext.sql("SELECT 1") + } + // test if the dialect set back to DefaultSQLDialect + assert(newContext.getSQLDialect().getClass === classOf[DefaultDialect]) + } + test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { checkAnswer( sql("SELECT a FROM testData2 SORT BY a"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index db096af4535a9..b165ab2b1deb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -22,6 +22,7 @@ import java.sql.DriverManager import java.util.{Calendar, GregorianCalendar, Properties} import org.apache.spark.sql.test._ +import org.apache.spark.sql.types._ import org.h2.jdbc.JdbcSQLException import org.scalatest.{FunSuite, BeforeAndAfter} import TestSQLContext._ @@ -256,12 +257,22 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) } + test("test DATE types in cache") { + val rows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").collect() + TestSQLContext + .jdbc(urlWithUserAndPass, "TEST.TIMETYPES").cache().registerTempTable("mycached_date") + val cachedRows = sql("select * from mycached_date").collect() + assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + } + test("H2 floating-point types") { val rows = sql("SELECT * FROM flttypes").collect() assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==. assert(rows(0).getDouble(1) === 1.00000011920928955) // Yes, I meant ==. assert(rows(0).getAs[BigDecimal](2) .equals(new BigDecimal("123456789012345.54321543215432100000"))) + assert(rows(0).schema.fields(2).dataType === DecimalType(40, 20)) } test("SQL query as table name") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 97c0f439acf13..b504842053690 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -381,6 +381,28 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } } } + + test("SPARK-6352 DirectParquetOutputCommitter") { + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. + try { + configuration.set("spark.sql.parquet.output.committer.class", + "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") + sqlContext.udf.register("div0", (x: Int) => x / 0) + withTempPath { dir => + intercept[org.apache.spark.SparkException] { + sqlContext.sql("select div0(1)").saveAsParquetFile(dir.getCanonicalPath) + } + val path = new Path(dir.getCanonicalPath, "_temporary") + val fs = path.getFileSystem(configuration) + assert(!fs.exists(path)) + } + } + finally { + configuration.set("spark.sql.parquet.output.committer.class", + "parquet.hadoop.ParquetOutputCommitter") + } + } } class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 21dce8d8a565a..e322340094e6f 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -183,7 +183,6 @@ org.apache.maven.plugins maven-dependency-plugin - 2.4 copy-dependencies diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index dd06b2620c5ee..1d8d0b5c322ad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -20,6 +20,9 @@ package org.apache.spark.sql.hive import java.io.{BufferedReader, InputStreamReader, PrintStream} import java.sql.Timestamp +import org.apache.hadoop.hive.ql.parse.VariableSubstitution +import org.apache.spark.sql.catalyst.Dialect + import scala.collection.JavaConversions._ import scala.language.implicitConversions @@ -42,6 +45,15 @@ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNative import org.apache.spark.sql.sources.{DDLParser, DataSourceStrategy} import org.apache.spark.sql.types._ +/** + * This is the HiveQL Dialect, this dialect is strongly bind with HiveContext + */ +private[hive] class HiveQLDialect extends Dialect { + override def parse(sqlText: String): LogicalPlan = { + HiveQl.parseSql(sqlText) + } +} + /** * An instance of the Spark SQL execution engine that integrates with data stored in Hive. * Configuration for Hive is read from hive-site.xml on the classpath. @@ -81,25 +93,16 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[sql] def convertCTAS: Boolean = getConf("spark.sql.hive.convertCTAS", "false").toBoolean - override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution(plan) - @transient - protected[sql] val ddlParserWithHiveQL = new DDLParser(HiveQl.parseSql(_)) - - override def sql(sqlText: String): DataFrame = { - val substituted = new VariableSubstitution().substitute(hiveconf, sqlText) - // TODO: Create a framework for registering parsers instead of just hardcoding if statements. - if (conf.dialect == "sql") { - super.sql(substituted) - } else if (conf.dialect == "hiveql") { - val ddlPlan = ddlParserWithHiveQL.parse(sqlText, exceptionOnError = false) - DataFrame(this, ddlPlan.getOrElse(HiveQl.parseSql(substituted))) - } else { - sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") - } + protected[sql] lazy val substitutor = new VariableSubstitution() + + protected[sql] override def parseSql(sql: String): LogicalPlan = { + super.parseSql(substitutor.substitute(hiveconf, sql)) } + override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = + new this.QueryExecution(plan) + /** * Invalidate and refresh all the cached the metadata of the given table. For performance reasons, * Spark SQL or the external data source library it uses might cache certain metadata about a @@ -356,6 +359,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } + override protected[sql] def dialectClassName = if (conf.dialect == "hiveql") { + classOf[HiveQLDialect].getCanonicalName + } else { + super.dialectClassName + } + @transient private val hivePlanner = new SparkPlanner with HiveStrategies { val hiveContext = self diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 0ea6d57b816c6..0a86519e1412b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -783,13 +783,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case (None, Some(perPartitionOrdering), None, None) => Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, withHaving) case (None, None, Some(partitionExprs), None) => - Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving) + RepartitionByExpression(partitionExprs.getChildren.map(nodeToExpr), withHaving) case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, - Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving)) + RepartitionByExpression(partitionExprs.getChildren.map(nodeToExpr), withHaving)) case (None, None, None, Some(clusterExprs)) => Sort(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), false, - Repartition(clusterExprs.getChildren.map(nodeToExpr), withHaving)) + RepartitionByExpression(clusterExprs.getChildren.map(nodeToExpr), withHaving)) case (None, None, None, None) => withHaving case _ => sys.error("Unsupported set of ordering / distribution clauses.") } @@ -887,13 +887,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon) && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon), s"Sampling fraction ($fraction) must be on interval [0, 100]") - Sample(fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt, + Sample(0.0, fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt, relation) case Token("TOK_TABLEBUCKETSAMPLE", Token(numerator, Nil) :: Token(denominator, Nil) :: Nil) => val fraction = numerator.toDouble / denominator.toDouble - Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation) case a: ASTNode => throw new NotImplementedError( s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} : diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 9f17bca083d13..edeab5158df62 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -107,7 +107,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { /** Fewer partitions to speed up testing. */ protected[sql] override lazy val conf: SQLConf = new SQLConf { override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt - override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") + + // TODO as in unit test, conf.clear() probably be called, all of the value will be cleared. + // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql" + override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 4f8d0ac0e7656..630dec8fa05a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -18,14 +18,17 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.hive.{MetastoreRelation, HiveShim} +import org.apache.spark.sql.catalyst.errors.DialectException +import org.apache.spark.sql.DefaultDialect +import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} +import org.apache.spark.sql.hive.MetastoreRelation import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim} import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) @@ -45,6 +48,9 @@ case class Order( state: String, month: Int) +/** A SQL Dialect for testing purpose, and it can not be nested type */ +class MyDialect extends DefaultDialect + /** * A collection of hive query tests where we generate the answers ourselves instead of depending on * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is @@ -229,6 +235,35 @@ class SQLQuerySuite extends QueryTest { setConf("spark.sql.hive.convertCTAS", originalConf) } + test("SQL Dialect Switching") { + assert(getSQLDialect().getClass === classOf[HiveQLDialect]) + setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) + assert(getSQLDialect().getClass === classOf[MyDialect]) + assert(sql("SELECT 1").collect() === Array(Row(1))) + + // set the dialect back to the DefaultSQLDialect + sql("SET spark.sql.dialect=sql") + assert(getSQLDialect().getClass === classOf[DefaultDialect]) + sql("SET spark.sql.dialect=hiveql") + assert(getSQLDialect().getClass === classOf[HiveQLDialect]) + + // set invalid dialect + sql("SET spark.sql.dialect.abc=MyTestClass") + sql("SET spark.sql.dialect=abc") + intercept[Exception] { + sql("SELECT 1") + } + // test if the dialect set back to HiveQLDialect + getSQLDialect().getClass === classOf[HiveQLDialect] + + sql("SET spark.sql.dialect=MyTestClass") + intercept[DialectException] { + sql("SELECT 1") + } + // test if the dialect set back to HiveQLDialect + assert(getSQLDialect().getClass === classOf[HiveQLDialect]) + } + test("CTAS with serde") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() sql( diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java new file mode 100644 index 0000000000000..8c0fdfa9c7478 --- /dev/null +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util; + +import java.nio.ByteBuffer; +import java.util.Iterator; + +/** + * This abstract class represents a write ahead log (aka journal) that is used by Spark Streaming + * to save the received data (by receivers) and associated metadata to a reliable storage, so that + * they can be recovered after driver failures. See the Spark documentation for more information + * on how to plug in your own custom implementation of a write ahead log. + */ +@org.apache.spark.annotation.DeveloperApi +public abstract class WriteAheadLog { + /** + * Write the record to the log and return a record handle, which contains all the information + * necessary to read back the written record. The time is used to the index the record, + * such that it can be cleaned later. Note that implementations of this abstract class must + * ensure that the written data is durable and readable (using the record handle) by the + * time this function returns. + */ + abstract public WriteAheadLogRecordHandle write(ByteBuffer record, long time); + + /** + * Read a written record based on the given record handle. + */ + abstract public ByteBuffer read(WriteAheadLogRecordHandle handle); + + /** + * Read and return an iterator of all the records that have been written but not yet cleaned up. + */ + abstract public Iterator readAll(); + + /** + * Clean all the records that are older than the threshold time. It can wait for + * the completion of the deletion. + */ + abstract public void clean(long threshTime, boolean waitForCompletion); + + /** + * Close this log and release any resources. + */ + abstract public void close(); +} diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java new file mode 100644 index 0000000000000..02324189b7822 --- /dev/null +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util; + +/** + * This abstract class represents a handle that refers to a record written in a + * {@link org.apache.spark.streaming.util.WriteAheadLog WriteAheadLog}. + * It must contain all the information necessary for the record to be read and returned by + * an implemenation of the WriteAheadLog class. + * + * @see org.apache.spark.streaming.util.WriteAheadLog + */ +@org.apache.spark.annotation.DeveloperApi +public abstract class WriteAheadLogRecordHandle implements java.io.Serializable { +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 0a50485118588..7bfae253c3a0c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -77,7 +77,8 @@ object Checkpoint extends Logging { } /** Get checkpoint files present in the give directory, ordered by oldest-first */ - def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = { + def getCheckpointFiles(checkpointDir: String, fsOption: Option[FileSystem] = None): Seq[Path] = { + def sortFunc(path1: Path, path2: Path): Boolean = { val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } @@ -85,6 +86,7 @@ object Checkpoint extends Logging { } val path = new Path(checkpointDir) + val fs = fsOption.getOrElse(path.getFileSystem(new Configuration())) if (fs.exists(path)) { val statuses = fs.listStatus(path) if (statuses != null) { @@ -160,7 +162,7 @@ class CheckpointWriter( } // Delete old checkpoint files - val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs) + val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)) if (allCheckpointFiles.size > 10) { allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => { logInfo("Deleting " + file) @@ -234,15 +236,24 @@ class CheckpointWriter( private[streaming] object CheckpointReader extends Logging { - def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = - { + /** + * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint + * files, then return None, else try to return the latest valid checkpoint object. If no + * checkpoint files could be read correctly, then return None (if ignoreReadError = true), + * or throw exception (if ignoreReadError = false). + */ + def read( + checkpointDir: String, + conf: SparkConf, + hadoopConf: Configuration, + ignoreReadError: Boolean = false): Option[Checkpoint] = { val checkpointPath = new Path(checkpointDir) // TODO(rxin): Why is this a def?! def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files - val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse + val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)).reverse if (checkpointFiles.isEmpty) { return None } @@ -282,7 +293,10 @@ object CheckpointReader extends Logging { }) // If none of checkpoint files could be read, then throw exception - throw new SparkException("Failed to read checkpoint from directory " + checkpointPath) + if (!ignoreReadError) { + throw new SparkException(s"Failed to read checkpoint from directory $checkpointPath") + } + None } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index f57f295874645..90c8b47aebce0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -107,6 +107,19 @@ class StreamingContext private[streaming] ( */ def this(path: String) = this(path, new Configuration) + /** + * Recreate a StreamingContext from a checkpoint file using an existing SparkContext. + * @param path Path to the directory that was specified as the checkpoint directory + * @param sparkContext Existing SparkContext + */ + def this(path: String, sparkContext: SparkContext) = { + this( + sparkContext, + CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).get, + null) + } + + if (sc_ == null && cp_ == null) { throw new Exception("Spark Streaming cannot be initialized with " + "both SparkContext and checkpoint as null") @@ -115,10 +128,12 @@ class StreamingContext private[streaming] ( private[streaming] val isCheckpointPresent = (cp_ != null) private[streaming] val sc: SparkContext = { - if (isCheckpointPresent) { + if (sc_ != null) { + sc_ + } else if (isCheckpointPresent) { new SparkContext(cp_.createSparkConf()) } else { - sc_ + throw new SparkException("Cannot create StreamingContext without a SparkContext") } } @@ -129,7 +144,7 @@ class StreamingContext private[streaming] ( private[streaming] val conf = sc.conf - private[streaming] val env = SparkEnv.get + private[streaming] val env = sc.env private[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { @@ -174,7 +189,9 @@ class StreamingContext private[streaming] ( /** Register streaming source to metrics system */ private val streamingSource = new StreamingSource(this) - SparkEnv.get.metricsSystem.registerSource(streamingSource) + assert(env != null) + assert(env.metricsSystem != null) + env.metricsSystem.registerSource(streamingSource) /** Enumeration to identify current state of the StreamingContext */ private[streaming] object StreamingContextState extends Enumeration { @@ -621,19 +638,59 @@ object StreamingContext extends Logging { hadoopConf: Configuration = new Configuration(), createOnError: Boolean = false ): StreamingContext = { - val checkpointOption = try { - CheckpointReader.read(checkpointPath, new SparkConf(), hadoopConf) - } catch { - case e: Exception => - if (createOnError) { - None - } else { - throw e - } - } + val checkpointOption = CheckpointReader.read( + checkpointPath, new SparkConf(), hadoopConf, createOnError) checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc()) } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the StreamingContext + * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note + * that the SparkConf configuration in the checkpoint data will not be restored as the + * SparkContext has already been created. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new StreamingContext using the given SparkContext + * @param sparkContext SparkContext using which the StreamingContext will be created + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: SparkContext => StreamingContext, + sparkContext: SparkContext + ): StreamingContext = { + getOrCreate(checkpointPath, creatingFunc, sparkContext, createOnError = false) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the StreamingContext + * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note + * that the SparkConf configuration in the checkpoint data will not be restored as the + * SparkContext has already been created. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new StreamingContext using the given SparkContext + * @param sparkContext SparkContext using which the StreamingContext will be created + * @param createOnError Whether to create a new StreamingContext if there is an + * error in reading checkpoint data. By default, an exception will be + * thrown on error. + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: SparkContext => StreamingContext, + sparkContext: SparkContext, + createOnError: Boolean + ): StreamingContext = { + val checkpointOption = CheckpointReader.read( + checkpointPath, sparkContext.conf, sparkContext.hadoopConfiguration, createOnError) + checkpointOption.map(new StreamingContext(sparkContext, _, null)) + .getOrElse(creatingFunc(sparkContext)) + } + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 4095a7cc84946..572d7d8e8753d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -32,13 +32,14 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} +import org.apache.spark.api.java.function.{Function0 => JFunction0} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.hadoop.conf.Configuration -import org.apache.spark.streaming.dstream.{PluggableInputDStream, ReceiverInputDStream, DStream} +import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver +import org.apache.hadoop.conf.Configuration /** * A Java-friendly version of [[org.apache.spark.streaming.StreamingContext]] which is the main @@ -655,6 +656,7 @@ object JavaStreamingContext { * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext */ + @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") def getOrCreate( checkpointPath: String, factory: JavaStreamingContextFactory @@ -676,6 +678,7 @@ object JavaStreamingContext { * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible * file system */ + @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -700,6 +703,7 @@ object JavaStreamingContext { * @param createOnError Whether to create a new JavaStreamingContext if there is an * error in reading checkpoint data. */ + @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -712,6 +716,117 @@ object JavaStreamingContext { new JavaStreamingContext(ssc) } + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext] + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext], + hadoopConf: Configuration + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }, hadoopConf) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + * @param createOnError Whether to create a new JavaStreamingContext if there is an + * error in reading checkpoint data. + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext], + hadoopConf: Configuration, + createOnError: Boolean + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }, hadoopConf, createOnError) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param sparkContext SparkContext using which the StreamingContext will be created + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], + sparkContext: JavaSparkContext + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { + creatingFunc.call(new JavaSparkContext(sparkContext)).ssc + }, sparkContext.sc) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param sparkContext SparkContext using which the StreamingContext will be created + * @param createOnError Whether to create a new JavaStreamingContext if there is an + * error in reading checkpoint data. + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], + sparkContext: JavaSparkContext, + createOnError: Boolean + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { + creatingFunc.call(new JavaSparkContext(sparkContext)).ssc + }, sparkContext.sc, createOnError) + new JavaStreamingContext(ssc) + } + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 24f99a2b929f5..83d41f5762444 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -626,7 +626,7 @@ abstract class DStream[T: ClassTag] ( println("Time: " + time) println("-------------------------------------------") firstNum.take(num).foreach(println) - if (firstNum.size > num) println("...") + if (firstNum.length > num) println("...") println() } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 8be04314c4285..4c7fd2c57c006 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -82,7 +82,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont // WriteAheadLogBackedBlockRDD else create simple BlockRDD. if (resultTypes.size == 1 && resultTypes.head == classOf[WriteAheadLogBasedStoreResult]) { val logSegments = blockStoreResults.map { - _.asInstanceOf[WriteAheadLogBasedStoreResult].segment + _.asInstanceOf[WriteAheadLogBasedStoreResult].walRecordHandle }.toArray // Since storeInBlockManager = false, the storage level does not matter. new WriteAheadLogBackedBlockRDD[T](ssc.sparkContext, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 93caa4ba35c7f..ebdf418f4ab6a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -16,14 +16,17 @@ */ package org.apache.spark.streaming.rdd +import java.nio.ByteBuffer + import scala.reflect.ClassTag +import scala.util.control.NonFatal -import org.apache.hadoop.conf.Configuration +import org.apache.commons.io.FileUtils import org.apache.spark._ import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{BlockId, StorageLevel} -import org.apache.spark.streaming.util.{HdfsUtils, WriteAheadLogFileSegment, WriteAheadLogRandomReader} +import org.apache.spark.streaming.util._ /** * Partition class for [[org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD]]. @@ -31,26 +34,27 @@ import org.apache.spark.streaming.util.{HdfsUtils, WriteAheadLogFileSegment, Wri * the segment of the write ahead log that backs the partition. * @param index index of the partition * @param blockId id of the block having the partition data - * @param segment segment of the write ahead log having the partition data + * @param walRecordHandle Handle of the record in a write ahead log having the partition data */ private[streaming] class WriteAheadLogBackedBlockRDDPartition( val index: Int, val blockId: BlockId, - val segment: WriteAheadLogFileSegment) + val walRecordHandle: WriteAheadLogRecordHandle) extends Partition /** * This class represents a special case of the BlockRDD where the data blocks in - * the block manager are also backed by segments in write ahead logs. For reading + * the block manager are also backed by data in write ahead logs. For reading * the data, this RDD first looks up the blocks by their ids in the block manager. - * If it does not find them, it looks up the corresponding file segment. + * If it does not find them, it looks up the corresponding data in the write ahead log. * * @param sc SparkContext * @param blockIds Ids of the blocks that contains this RDD's data - * @param segments Segments in write ahead logs that contain this RDD's data - * @param storeInBlockManager Whether to store in the block manager after reading from the segment + * @param walRecordHandles Record handles in write ahead logs that contain this RDD's data + * @param storeInBlockManager Whether to store in the block manager after reading + * from the WAL record * @param storageLevel storage level to store when storing in block manager * (applicable when storeInBlockManager = true) */ @@ -58,15 +62,15 @@ private[streaming] class WriteAheadLogBackedBlockRDD[T: ClassTag]( @transient sc: SparkContext, @transient blockIds: Array[BlockId], - @transient segments: Array[WriteAheadLogFileSegment], + @transient walRecordHandles: Array[WriteAheadLogRecordHandle], storeInBlockManager: Boolean, storageLevel: StorageLevel) extends BlockRDD[T](sc, blockIds) { require( - blockIds.length == segments.length, + blockIds.length == walRecordHandles.length, s"Number of block ids (${blockIds.length}) must be " + - s"the same as number of segments (${segments.length}})!") + s"the same as number of WAL record handles (${walRecordHandles.length}})!") // Hadoop configuration is not serializable, so broadcast it as a serializable. @transient private val hadoopConfig = sc.hadoopConfiguration @@ -75,13 +79,13 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { assertValid() Array.tabulate(blockIds.size) { i => - new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), segments(i)) + new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), walRecordHandles(i)) } } /** * Gets the partition data by getting the corresponding block from the block manager. - * If the block does not exist, then the data is read from the corresponding segment + * If the block does not exist, then the data is read from the corresponding record * in write ahead log files. */ override def compute(split: Partition, context: TaskContext): Iterator[T] = { @@ -96,10 +100,35 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( logDebug(s"Read partition data of $this from block manager, block $blockId") iterator case None => // Data not found in Block Manager, grab it from write ahead log file - val reader = new WriteAheadLogRandomReader(partition.segment.path, hadoopConf) - val dataRead = reader.read(partition.segment) - reader.close() - logInfo(s"Read partition data of $this from write ahead log, segment ${partition.segment}") + var dataRead: ByteBuffer = null + var writeAheadLog: WriteAheadLog = null + try { + // The WriteAheadLogUtils.createLog*** method needs a directory to create a + // WriteAheadLog object as the default FileBasedWriteAheadLog needs a directory for + // writing log data. However, the directory is not needed if data needs to be read, hence + // a dummy path is provided to satisfy the method parameter requirements. + // FileBasedWriteAheadLog will not create any file or directory at that path. + val dummyDirectory = FileUtils.getTempDirectoryPath() + writeAheadLog = WriteAheadLogUtils.createLogForReceiver( + SparkEnv.get.conf, dummyDirectory, hadoopConf) + dataRead = writeAheadLog.read(partition.walRecordHandle) + } catch { + case NonFatal(e) => + throw new SparkException( + s"Could not read data from write ahead log record ${partition.walRecordHandle}", e) + } finally { + if (writeAheadLog != null) { + writeAheadLog.close() + writeAheadLog = null + } + } + if (dataRead == null) { + throw new SparkException( + s"Could not read data from write ahead log record ${partition.walRecordHandle}, " + + s"read returned null") + } + logInfo(s"Read partition data of $this from write ahead log, record handle " + + partition.walRecordHandle) if (storeInBlockManager) { blockManager.putBytes(blockId, dataRead, storageLevel) logDebug(s"Stored partition data of $this into block manager with level $storageLevel") @@ -111,14 +140,20 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( /** * Get the preferred location of the partition. This returns the locations of the block - * if it is present in the block manager, else it returns the location of the - * corresponding segment in HDFS. + * if it is present in the block manager, else if FileBasedWriteAheadLogSegment is used, + * it returns the location of the corresponding file segment in HDFS . */ override def getPreferredLocations(split: Partition): Seq[String] = { val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition] val blockLocations = getBlockIdLocations().get(partition.blockId) - blockLocations.getOrElse( - HdfsUtils.getFileSegmentLocations( - partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig)) + blockLocations.getOrElse { + partition.walRecordHandle match { + case fileSegment: FileBasedWriteAheadLogSegment => + HdfsUtils.getFileSegmentLocations( + fileSegment.path, fileSegment.offset, fileSegment.length, hadoopConfig) + case _ => + Seq.empty + } + } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index f4963a78e1d18..4bebcc5aa7ca0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -126,6 +126,20 @@ private[streaming] class BlockGenerator( listener.onAddData(data, metadata) } + /** + * Push multiple data items into the buffer. After buffering the data, the + * `BlockGeneratorListener.onAddData` callback will be called. All received data items + * will be periodically pushed into BlockManager. Note that all the data items is guaranteed + * to be present in a single block. + */ + def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = synchronized { + dataIterator.foreach { data => + waitToPush() + currentBuffer += data + } + listener.onAddData(dataIterator, metadata) + } + /** Change the buffer to which single records are added to. */ private def updateCurrentBuffer(time: Long): Unit = synchronized { try { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 297bf04c0c25e..4b3d9ee4b0090 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -17,18 +17,18 @@ package org.apache.spark.streaming.receiver -import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ +import scala.concurrent.{Await, ExecutionContext, Future} import scala.language.{existentials, postfixOps} -import WriteAheadLogBasedBlockHandler._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage._ -import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogManager} -import org.apache.spark.util.{ThreadUtils, Clock, SystemClock} +import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ +import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} +import org.apache.spark.{Logging, SparkConf, SparkException} /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { @@ -96,7 +96,7 @@ private[streaming] class BlockManagerBasedBlockHandler( */ private[streaming] case class WriteAheadLogBasedStoreResult( blockId: StreamBlockId, - segment: WriteAheadLogFileSegment + walRecordHandle: WriteAheadLogRecordHandle ) extends ReceivedBlockStoreResult @@ -116,10 +116,6 @@ private[streaming] class WriteAheadLogBasedBlockHandler( private val blockStoreTimeout = conf.getInt( "spark.streaming.receiver.blockStoreTimeout", 30).seconds - private val rollingInterval = conf.getInt( - "spark.streaming.receiver.writeAheadLog.rollingInterval", 60) - private val maxFailures = conf.getInt( - "spark.streaming.receiver.writeAheadLog.maxFailures", 3) private val effectiveStorageLevel = { if (storageLevel.deserialized) { @@ -139,13 +135,9 @@ private[streaming] class WriteAheadLogBasedBlockHandler( s"$effectiveStorageLevel when write ahead log is enabled") } - // Manages rolling log files - private val logManager = new WriteAheadLogManager( - checkpointDirToLogDir(checkpointDir, streamId), - hadoopConf, rollingInterval, maxFailures, - callerName = this.getClass.getSimpleName, - clock = clock - ) + // Write ahead log manages + private val writeAheadLog = WriteAheadLogUtils.createLogForReceiver( + conf, checkpointDirToLogDir(checkpointDir, streamId), hadoopConf) // For processing futures used in parallel block storing into block manager and write ahead log // # threads = 2, so that both writing to BM and WAL can proceed in parallel @@ -183,21 +175,21 @@ private[streaming] class WriteAheadLogBasedBlockHandler( // Store the block in write ahead log val storeInWriteAheadLogFuture = Future { - logManager.writeToLog(serializedBlock) + writeAheadLog.write(serializedBlock, clock.getTimeMillis()) } - // Combine the futures, wait for both to complete, and return the write ahead log segment + // Combine the futures, wait for both to complete, and return the write ahead log record handle val combinedFuture = storeInBlockManagerFuture.zip(storeInWriteAheadLogFuture).map(_._2) - val segment = Await.result(combinedFuture, blockStoreTimeout) - WriteAheadLogBasedStoreResult(blockId, segment) + val walRecordHandle = Await.result(combinedFuture, blockStoreTimeout) + WriteAheadLogBasedStoreResult(blockId, walRecordHandle) } def cleanupOldBlocks(threshTime: Long) { - logManager.cleanupOldLogs(threshTime, waitForCompletion = false) + writeAheadLog.clean(threshTime, false) } def stop() { - logManager.stop() + writeAheadLog.close() } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 89af40330b9d9..93f047b91018f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -25,12 +25,13 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Throwables import org.apache.hadoop.conf.Configuration -import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.rpc.{RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.{Logging, SparkEnv, SparkException} /** * Concrete implementation of [[org.apache.spark.streaming.receiver.ReceiverSupervisor]] @@ -46,7 +47,7 @@ private[streaming] class ReceiverSupervisorImpl( ) extends ReceiverSupervisor(receiver, env.conf) with Logging { private val receivedBlockHandler: ReceivedBlockHandler = { - if (env.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + if (WriteAheadLogUtils.enableReceiverLog(env.conf)) { if (checkpointDirOption.isEmpty) { throw new SparkException( "Cannot enable receiver write-ahead log without checkpoint directory set. " + @@ -146,7 +147,7 @@ private[streaming] class ReceiverSupervisorImpl( logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") val blockInfo = ReceivedBlockInfo(streamId, numRecords, blockStoreResult) - trackerEndpoint.askWithReply[Boolean](AddBlock(blockInfo)) + trackerEndpoint.askWithRetry[Boolean](AddBlock(blockInfo)) logDebug(s"Reported block $blockId") } @@ -169,13 +170,13 @@ private[streaming] class ReceiverSupervisorImpl( override protected def onReceiverStart() { val msg = RegisterReceiver( streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint) - trackerEndpoint.askWithReply[Boolean](msg) + trackerEndpoint.askWithRetry[Boolean](msg) } override protected def onReceiverStop(message: String, error: Option[Throwable]) { logInfo("Deregistering receiver " + streamId) val errorString = error.map(Throwables.getStackTraceAsString).getOrElse("") - trackerEndpoint.askWithReply[Boolean](DeregisterReceiver(streamId, message, errorString)) + trackerEndpoint.askWithRetry[Boolean](DeregisterReceiver(streamId, message, errorString)) logInfo("Stopped receiver " + streamId) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index 30cf87f5b7dd1..3c481bf3491f9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -25,15 +25,49 @@ import scala.util.Try */ private[streaming] class Job(val time: Time, func: () => _) { - var id: String = _ - var result: Try[_] = null + private var _id: String = _ + private var _outputOpId: Int = _ + private var isSet = false + private var _result: Try[_] = null def run() { - result = Try(func()) + _result = Try(func()) } - def setId(number: Int) { - id = "streaming job " + time + "." + number + def result: Try[_] = { + if (_result == null) { + throw new IllegalStateException("Cannot access result before job finishes") + } + _result + } + + /** + * @return the global unique id of this Job. + */ + def id: String = { + if (!isSet) { + throw new IllegalStateException("Cannot access id before calling setId") + } + _id + } + + /** + * @return the output op id of this Job. Each Job has a unique output op id in the same JobSet. + */ + def outputOpId: Int = { + if (!isSet) { + throw new IllegalStateException("Cannot access number before calling setId") + } + _outputOpId + } + + def setOutputOpId(outputOpId: Int) { + if (isSet) { + throw new IllegalStateException("Cannot call setOutputOpId more than once") + } + isSet = true + _id = s"streaming job $time.$outputOpId" + _outputOpId = outputOpId } override def toString: String = id diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 508b89278dcba..c7a2c1141a128 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -172,16 +172,28 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { ssc.waiter.notifyError(e) } - private class JobHandler(job: Job) extends Runnable { + private class JobHandler(job: Job) extends Runnable with Logging { def run() { - eventLoop.post(JobStarted(job)) - // Disable checks for existing output directories in jobs launched by the streaming scheduler, - // since we may need to write output to an existing directory during checkpoint recovery; - // see SPARK-4835 for more details. - PairRDDFunctions.disableOutputSpecValidation.withValue(true) { - job.run() + ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString) + ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString) + try { + eventLoop.post(JobStarted(job)) + // Disable checks for existing output directories in jobs launched by the streaming + // scheduler, since we may need to write output to an existing directory during checkpoint + // recovery; see SPARK-4835 for more details. + PairRDDFunctions.disableOutputSpecValidation.withValue(true) { + job.run() + } + eventLoop.post(JobCompleted(job)) + } finally { + ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, null) + ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, null) } - eventLoop.post(JobCompleted(job)) } } } + +private[streaming] object JobScheduler { + val BATCH_TIME_PROPERTY_KEY = "spark.streaming.internal.batchTime" + val OUTPUT_OP_ID_PROPERTY_KEY = "spark.streaming.internal.outputOpId" +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 5b134877d0b2d..24b3794236ea5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -35,7 +35,7 @@ case class JobSet( private var processingStartTime = -1L // when the first job of this jobset started processing private var processingEndTime = -1L // when the last job of this jobset finished processing - jobs.zipWithIndex.foreach { case (job, i) => job.setId(i) } + jobs.zipWithIndex.foreach { case (job, i) => job.setOutputOpId(i) } incompleteJobs ++= jobs def handleJobStart(job: Job) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 200cf4ef4b0f1..14e769a281f51 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -25,10 +25,10 @@ import scala.language.implicitConversions import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.streaming.Time -import org.apache.spark.streaming.util.WriteAheadLogManager +import org.apache.spark.streaming.util.{WriteAheadLog, WriteAheadLogUtils} import org.apache.spark.util.{Clock, Utils} +import org.apache.spark.{Logging, SparkConf, SparkException} /** Trait representing any event in the ReceivedBlockTracker that updates its state. */ private[streaming] sealed trait ReceivedBlockTrackerLogEvent @@ -70,7 +70,7 @@ private[streaming] class ReceivedBlockTracker( private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue] private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks] - private val logManagerOption = createLogManager() + private val writeAheadLogOption = createWriteAheadLog() private var lastAllocatedBatchTime: Time = null @@ -155,12 +155,12 @@ private[streaming] class ReceivedBlockTracker( logInfo("Deleting batches " + timesToCleanup) writeToLog(BatchCleanupEvent(timesToCleanup)) timeToAllocatedBlocks --= timesToCleanup - logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds, waitForCompletion)) + writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion)) } /** Stop the block tracker. */ def stop() { - logManagerOption.foreach { _.stop() } + writeAheadLogOption.foreach { _.close() } } /** @@ -190,9 +190,10 @@ private[streaming] class ReceivedBlockTracker( timeToAllocatedBlocks --= batchTimes } - logManagerOption.foreach { logManager => + writeAheadLogOption.foreach { writeAheadLog => logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") - logManager.readFromLog().foreach { byteBuffer => + import scala.collection.JavaConversions._ + writeAheadLog.readAll().foreach { byteBuffer => logTrace("Recovering record " + byteBuffer) Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) match { case BlockAdditionEvent(receivedBlockInfo) => @@ -208,10 +209,10 @@ private[streaming] class ReceivedBlockTracker( /** Write an update to the tracker to the write ahead log */ private def writeToLog(record: ReceivedBlockTrackerLogEvent) { - if (isLogManagerEnabled) { + if (isWriteAheadLogEnabled) { logDebug(s"Writing to log $record") - logManagerOption.foreach { logManager => - logManager.writeToLog(ByteBuffer.wrap(Utils.serialize(record))) + writeAheadLogOption.foreach { logManager => + logManager.write(ByteBuffer.wrap(Utils.serialize(record)), clock.getTimeMillis()) } } } @@ -222,8 +223,8 @@ private[streaming] class ReceivedBlockTracker( } /** Optionally create the write ahead log manager only if the feature is enabled */ - private def createLogManager(): Option[WriteAheadLogManager] = { - if (conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + private def createWriteAheadLog(): Option[WriteAheadLog] = { + if (WriteAheadLogUtils.enableReceiverLog(conf)) { if (checkpointDirOption.isEmpty) { throw new SparkException( "Cannot enable receiver write-ahead log without checkpoint directory set. " + @@ -231,19 +232,16 @@ private[streaming] class ReceivedBlockTracker( "See documentation for more details.") } val logDir = ReceivedBlockTracker.checkpointDirToLogDir(checkpointDirOption.get) - val rollingIntervalSecs = conf.getInt( - "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", 60) - val logManager = new WriteAheadLogManager(logDir, hadoopConf, - rollingIntervalSecs = rollingIntervalSecs, clock = clock, - callerName = "ReceivedBlockHandlerMaster") - Some(logManager) + + val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf) + Some(log) } else { None } } - /** Check if the log manager is enabled. This is only used for testing purposes. */ - private[streaming] def isLogManagerEnabled: Boolean = logManagerOption.nonEmpty + /** Check if the write ahead log is enabled. This is only used for testing purposes. */ + private[streaming] def isWriteAheadLogEnabled: Boolean = writeAheadLogOption.nonEmpty } private[streaming] object ReceivedBlockTracker { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index c4ead6f30a63d..1af65716d3003 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable.{HashMap, SynchronizedMap} import scala.language.existentials +import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} @@ -125,7 +126,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false receivedBlockTracker.cleanupOldBatches(cleanupThreshTime, waitForCompletion = false) // Signal the receivers to delete old block data - if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { logInfo(s"Cleanup old received batch data: $cleanupThreshTime") receiverInfo.values.flatMap { info => Option(info.endpoint) } .foreach { _.send(CleanupOldBlocks(cleanupThreshTime)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala index df1c0a10704c3..e219e27785533 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -19,7 +19,6 @@ package org.apache.spark.streaming.ui import scala.xml.Node -import org.apache.spark.streaming.scheduler.BatchInfo import org.apache.spark.ui.UIUtils private[ui] abstract class BatchTableBase(tableId: String) { @@ -31,18 +30,20 @@ private[ui] abstract class BatchTableBase(tableId: String) { } - protected def baseRow(batch: BatchInfo): Seq[Node] = { + protected def baseRow(batch: BatchUIData): Seq[Node] = { val batchTime = batch.batchTime.milliseconds val formattedBatchTime = UIUtils.formatDate(batch.batchTime.milliseconds) - val eventCount = batch.receivedBlockInfo.values.map { - receivers => receivers.map(_.numRecords).sum - }.sum + val eventCount = batch.numRecords val schedulingDelay = batch.schedulingDelay val formattedSchedulingDelay = schedulingDelay.map(UIUtils.formatDuration).getOrElse("-") val processingTime = batch.processingDelay val formattedProcessingTime = processingTime.map(UIUtils.formatDuration).getOrElse("-") - + @@ -85,16 +87,16 @@ private[ui] class ActiveBatchTable(runningBatches: Seq[BatchInfo], waitingBatche runningBatches.flatMap(batch => {runningBatchRow(batch)}) } - private def runningBatchRow(batch: BatchInfo): Seq[Node] = { + private def runningBatchRow(batch: BatchUIData): Seq[Node] = { baseRow(batch) ++ } - private def waitingBatchRow(batch: BatchInfo): Seq[Node] = { + private def waitingBatchRow(batch: BatchUIData): Seq[Node] = { baseRow(batch) ++ } } -private[ui] class CompletedBatchTable(batches: Seq[BatchInfo]) +private[ui] class CompletedBatchTable(batches: Seq[BatchUIData]) extends BatchTableBase("completed-batches-table") { override protected def columns: Seq[Node] = super.columns ++ @@ -103,7 +105,7 @@ private[ui] class CompletedBatchTable(batches: Seq[BatchInfo]) batches.flatMap(batch => {completedBatchRow(batch)}) } - private def completedBatchRow(batch: BatchInfo): Seq[Node] = { + private def completedBatchRow(batch: BatchUIData): Seq[Node] = { val totalDelay = batch.totalDelay val formattedTotalDelay = totalDelay.map(UIUtils.formatDuration).getOrElse("-") baseRow(batch) ++ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala new file mode 100644 index 0000000000000..2da9a29e2529e --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.{NodeSeq, Node} + +import org.apache.commons.lang3.StringEscapeUtils + +import org.apache.spark.streaming.Time +import org.apache.spark.ui.{UIUtils, WebUIPage} +import org.apache.spark.streaming.ui.StreamingJobProgressListener.{SparkJobId, OutputOpId} +import org.apache.spark.ui.jobs.UIData.JobUIData + + +private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { + private val streamingListener = parent.listener + private val sparkListener = parent.ssc.sc.jobProgressListener + + private def columns: Seq[Node] = { + + + + + + + + + } + + /** + * Generate a row for a Spark Job. Because duplicated output op infos needs to be collapsed into + * one cell, we use "rowspan" for the first row of a output op. + */ + def generateJobRow( + outputOpId: OutputOpId, + formattedOutputOpDuration: String, + numSparkJobRowsInOutputOp: Int, + isFirstRow: Boolean, + sparkJob: JobUIData): Seq[Node] = { + val lastStageInfo = Option(sparkJob.stageIds) + .filter(_.nonEmpty) + .flatMap { ids => sparkListener.stageIdToInfo.get(ids.max) } + val lastStageData = lastStageInfo.flatMap { s => + sparkListener.stageIdToData.get((s.stageId, s.attemptId)) + } + + val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") + val lastStageDescription = lastStageData.flatMap(_.description).getOrElse("") + val duration: Option[Long] = { + sparkJob.submissionTime.map { start => + val end = sparkJob.completionTime.getOrElse(System.currentTimeMillis()) + end - start + } + } + val lastFailureReason = + sparkJob.stageIds.sorted.reverse.flatMap(sparkListener.stageIdToInfo.get). + dropWhile(_.failureReason == None).take(1). // get the first info that contains failure + flatMap(info => info.failureReason).headOption.getOrElse("") + val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("-") + val detailUrl = s"${UIUtils.prependBaseUri(parent.basePath)}/jobs/job?id=${sparkJob.jobId}" + + // In the first row, output op id and its information needs to be shown. In other rows, these + // cells will be taken up due to "rowspan". + // scalastyle:off + val prefixCells = + if (isFirstRow) { + + + + } else { + Nil + } + // scalastyle:on + + + {prefixCells} + + + + + {failureReasonCell(lastFailureReason)} + + } + + private def generateOutputOpIdRow( + outputOpId: OutputOpId, sparkJobs: Seq[JobUIData]): Seq[Node] = { + val sparkjobDurations = sparkJobs.map(sparkJob => { + sparkJob.submissionTime.map { start => + val end = sparkJob.completionTime.getOrElse(System.currentTimeMillis()) + end - start + } + }) + val formattedOutputOpDuration = + if (sparkjobDurations.exists(_ == None)) { + // If any job does not finish, set "formattedOutputOpDuration" to "-" + "-" + } else { + UIUtils.formatDuration(sparkjobDurations.flatMap(x => x).sum) + } + generateJobRow(outputOpId, formattedOutputOpDuration, sparkJobs.size, true, sparkJobs.head) ++ + sparkJobs.tail.map { sparkJob => + generateJobRow(outputOpId, formattedOutputOpDuration, sparkJobs.size, false, sparkJob) + }.flatMap(x => x) + } + + private def failureReasonCell(failureReason: String): Seq[Node] = { + val isMultiline = failureReason.indexOf('\n') >= 0 + // Display the first line by default + val failureReasonSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + failureReason.substring(0, failureReason.indexOf('\n')) + } else { + failureReason + }) + val details = if (isMultiline) { + // scalastyle:off + + +details + ++ + + // scalastyle:on + } else { + "" + } + + } + + private def getJobData(sparkJobId: SparkJobId): Option[JobUIData] = { + sparkListener.activeJobs.get(sparkJobId).orElse { + sparkListener.completedJobs.find(_.jobId == sparkJobId).orElse { + sparkListener.failedJobs.find(_.jobId == sparkJobId) + } + } + } + + /** + * Generate the job table for the batch. + */ + private def generateJobTable(batchUIData: BatchUIData): Seq[Node] = { + val outputOpIdToSparkJobIds = batchUIData.outputOpIdSparkJobIdPairs.groupBy(_.outputOpId).toSeq. + sortBy(_._1). // sorted by OutputOpId + map { case (outputOpId, outputOpIdAndSparkJobIds) => + // sort SparkJobIds for each OutputOpId + (outputOpId, outputOpIdAndSparkJobIds.map(_.sparkJobId).sorted) + } + sparkListener.synchronized { + val outputOpIdWithJobs: Seq[(OutputOpId, Seq[JobUIData])] = + outputOpIdToSparkJobIds.map { case (outputOpId, sparkJobIds) => + // Filter out spark Job ids that don't exist in sparkListener + (outputOpId, sparkJobIds.flatMap(getJobData)) + } + +
Property NameDefaultMeaning
spark.reducer.maxMbInFlight48spark.reducer.maxSizeInFlight48m - Maximum size (in megabytes) of map outputs to fetch simultaneously from each reduce task. Since + Maximum size of map outputs to fetch simultaneously from each reduce task. Since each output requires us to create a buffer to receive it, this represents a fixed memory overhead per reduce task, so keep it small unless you have a large amount of memory.
spark.shuffle.file.buffer.kb32spark.shuffle.file.buffer32k - Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers + Size of the in-memory buffer for each shuffle file output stream. These buffers reduce the number of disk seeks and system calls made in creating intermediate shuffle files.
spark.io.compression.lz4.block.size32768spark.io.compression.lz4.blockSize32k - Block size (in bytes) used in LZ4 compression, in the case when LZ4 compression codec + Block size used in LZ4 compression, in the case when LZ4 compression codec is used. Lowering this block size will also lower shuffle memory usage when LZ4 is used.
spark.io.compression.snappy.block.size32768spark.io.compression.snappy.blockSize32k - Block size (in bytes) used in Snappy compression, in the case when Snappy compression codec + Block size used in Snappy compression, in the case when Snappy compression codec is used. Lowering this block size will also lower shuffle memory usage when Snappy is used.
spark.kryoserializer.buffer.max.mb64spark.kryoserializer.buffer.max64m - Maximum allowable size of Kryo serialization buffer, in megabytes. This must be larger than any + Maximum allowable size of Kryo serialization buffer. This must be larger than any object you attempt to serialize. Increase this if you get a "buffer limit exceeded" exception inside Kryo.
spark.kryoserializer.buffer.mb0.064spark.kryoserializer.buffer64k - Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer + Initial size of Kryo's serialization buffer. Note that there will be one buffer per core on each worker. This buffer will grow up to spark.kryoserializer.buffer.max.mb if needed.
Property NameDefaultMeaning
spark.broadcast.blockSize40964m - Size of each piece of a block in kilobytes for TorrentBroadcastFactory. + Size of each piece of a block for TorrentBroadcastFactory. Too large a value decreases parallelism during broadcast (makes it slower); however, if it is too small, BlockManager might take a performance hit.
spark.storage.memoryMapThreshold20971522m - Size of a block, in bytes, above which Spark memory maps when reading a block from disk. + Size of a block above which Spark memory maps when reading a block from disk. This prevents Spark from memory mapping very small blocks. In general, memory mapping has high overhead for blocks close to or below the page size of the operating system. Processing Time{formattedBatchTime} + + {formattedBatchTime} + + {eventCount.toString} events {formattedSchedulingDelay} @@ -73,8 +74,9 @@ private[ui] abstract class BatchTableBase(tableId: String) { protected def renderRows: Seq[Node] } -private[ui] class ActiveBatchTable(runningBatches: Seq[BatchInfo], waitingBatches: Seq[BatchInfo]) - extends BatchTableBase("active-batches-table") { +private[ui] class ActiveBatchTable( + runningBatches: Seq[BatchUIData], + waitingBatches: Seq[BatchUIData]) extends BatchTableBase("active-batches-table") { override protected def columns: Seq[Node] = super.columns ++ Status
processingqueuedTotal Delay
Output Op IdDescriptionDurationJob IdDurationStages: Succeeded/TotalTasks (for all stages): Succeeded/TotalError{outputOpId.toString} + + {lastStageDescription} + {lastStageName} + {formattedOutputOpDuration}
+ + {sparkJob.jobId}{sparkJob.jobGroup.map(id => s"($id)").getOrElse("")} + + + {formattedDuration} + + {sparkJob.completedStageIndices.size}/{sparkJob.stageIds.size - sparkJob.numSkippedStages} + {if (sparkJob.numFailedStages > 0) s"(${sparkJob.numFailedStages} failed)"} + {if (sparkJob.numSkippedStages > 0) s"(${sparkJob.numSkippedStages} skipped)"} + + { + UIUtils.makeProgressBar( + started = sparkJob.numActiveTasks, + completed = sparkJob.numCompletedTasks, + failed = sparkJob.numFailedTasks, + skipped = sparkJob.numSkippedTasks, + total = sparkJob.numTasks - sparkJob.numSkippedTasks) + } +
{failureReasonSummary}{details}
+ + {columns} + + + { + outputOpIdWithJobs.map { + case (outputOpId, jobs) => generateOutputOpIdRow(outputOpId, jobs) + } + } + +
+ } + } + + def render(request: HttpServletRequest): Seq[Node] = { + val batchTime = Option(request.getParameter("id")).map(id => Time(id.toLong)).getOrElse { + throw new IllegalArgumentException(s"Missing id parameter") + } + val formattedBatchTime = UIUtils.formatDate(batchTime.milliseconds) + + val batchUIData = streamingListener.getBatchUIData(batchTime).getOrElse { + throw new IllegalArgumentException(s"Batch $formattedBatchTime does not exist") + } + + val formattedSchedulingDelay = + batchUIData.schedulingDelay.map(UIUtils.formatDuration).getOrElse("-") + val formattedProcessingTime = + batchUIData.processingDelay.map(UIUtils.formatDuration).getOrElse("-") + val formattedTotalDelay = batchUIData.totalDelay.map(UIUtils.formatDuration).getOrElse("-") + + val summary: NodeSeq = +
+
    +
  • + Batch Duration: + {UIUtils.formatDuration(streamingListener.batchDuration)} +
  • +
  • + Input data size: + {batchUIData.numRecords} records +
  • +
  • + Scheduling delay: + {formattedSchedulingDelay} +
  • +
  • + Processing time: + {formattedProcessingTime} +
  • +
  • + Total delay: + {formattedTotalDelay} +
  • +
+
+ + val jobTable = + if (batchUIData.outputOpIdSparkJobIdPairs.isEmpty) { +
Cannot find any job for Batch {formattedBatchTime}.
+ } else { + generateJobTable(batchUIData) + } + + val content = summary ++ jobTable + + UIUtils.headerSparkPage(s"Details of batch at $formattedBatchTime", content, parent) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala new file mode 100644 index 0000000000000..f45c291b7c0fe --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.streaming.ui + +import org.apache.spark.streaming.Time +import org.apache.spark.streaming.scheduler.BatchInfo +import org.apache.spark.streaming.ui.StreamingJobProgressListener._ + +private[ui] case class OutputOpIdAndSparkJobId(outputOpId: OutputOpId, sparkJobId: SparkJobId) + +private[ui] case class BatchUIData( + val batchTime: Time, + val receiverNumRecords: Map[Int, Long], + val submissionTime: Long, + val processingStartTime: Option[Long], + val processingEndTime: Option[Long], + var outputOpIdSparkJobIdPairs: Seq[OutputOpIdAndSparkJobId] = Seq.empty) { + + /** + * Time taken for the first job of this batch to start processing from the time this batch + * was submitted to the streaming scheduler. Essentially, it is + * `processingStartTime` - `submissionTime`. + */ + def schedulingDelay: Option[Long] = processingStartTime.map(_ - submissionTime) + + /** + * Time taken for the all jobs of this batch to finish processing from the time they started + * processing. Essentially, it is `processingEndTime` - `processingStartTime`. + */ + def processingDelay: Option[Long] = { + for (start <- processingStartTime; + end <- processingEndTime) + yield end - start + } + + /** + * Time taken for all the jobs of this batch to finish processing from the time they + * were submitted. Essentially, it is `processingDelay` + `schedulingDelay`. + */ + def totalDelay: Option[Long] = processingEndTime.map(_ - submissionTime) + + /** + * The number of recorders received by the receivers in this batch. + */ + def numRecords: Long = receiverNumRecords.map(_._2).sum +} + +private[ui] object BatchUIData { + + def apply(batchInfo: BatchInfo): BatchUIData = { + new BatchUIData( + batchInfo.batchTime, + batchInfo.receivedBlockInfo.mapValues(_.map(_.numRecords).sum), + batchInfo.submissionTime, + batchInfo.processingStartTime, + batchInfo.processingEndTime + ) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index be1e8686cf9fa..34b55717a1db2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -17,29 +17,58 @@ package org.apache.spark.streaming.ui -import scala.collection.mutable.{Queue, HashMap} +import java.util.LinkedHashMap +import java.util.{Map => JMap} +import java.util.Properties +import scala.collection.mutable.{ArrayBuffer, Queue, HashMap, SynchronizedBuffer} + +import org.apache.spark.scheduler._ import org.apache.spark.streaming.{Time, StreamingContext} import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted import org.apache.spark.streaming.scheduler.StreamingListenerBatchStarted -import org.apache.spark.streaming.scheduler.BatchInfo import org.apache.spark.streaming.scheduler.StreamingListenerBatchSubmitted import org.apache.spark.util.Distribution private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) - extends StreamingListener { + extends StreamingListener with SparkListener { - private val waitingBatchInfos = new HashMap[Time, BatchInfo] - private val runningBatchInfos = new HashMap[Time, BatchInfo] - private val completedBatchInfos = new Queue[BatchInfo] - private val batchInfoLimit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 100) + private val waitingBatchUIData = new HashMap[Time, BatchUIData] + private val runningBatchUIData = new HashMap[Time, BatchUIData] + private val completedBatchUIData = new Queue[BatchUIData] + private val batchUIDataLimit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 100) private var totalCompletedBatches = 0L private var totalReceivedRecords = 0L private var totalProcessedRecords = 0L private val receiverInfos = new HashMap[Int, ReceiverInfo] + // Because onJobStart and onBatchXXX messages are processed in different threads, + // we may not be able to get the corresponding BatchUIData when receiving onJobStart. So here we + // cannot use a map of (Time, BatchUIData). + private[ui] val batchTimeToOutputOpIdSparkJobIdPair = + new LinkedHashMap[Time, SynchronizedBuffer[OutputOpIdAndSparkJobId]] { + override def removeEldestEntry( + p1: JMap.Entry[Time, SynchronizedBuffer[OutputOpIdAndSparkJobId]]): Boolean = { + // If a lot of "onBatchCompleted"s happen before "onJobStart" (image if + // SparkContext.listenerBus is very slow), "batchTimeToOutputOpIdToSparkJobIds" + // may add some information for a removed batch when processing "onJobStart". It will be a + // memory leak. + // + // To avoid the memory leak, we control the size of "batchTimeToOutputOpIdToSparkJobIds" and + // evict the eldest one. + // + // Note: if "onJobStart" happens before "onBatchSubmitted", the size of + // "batchTimeToOutputOpIdToSparkJobIds" may be greater than the number of the retained + // batches temporarily, so here we use "10" to handle such case. This is not a perfect + // solution, but at least it can handle most of cases. + size() > + waitingBatchUIData.size + runningBatchUIData.size + completedBatchUIData.size + 10 + } + } + + val batchDuration = ssc.graph.batchDuration.milliseconds override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { @@ -62,37 +91,62 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = { synchronized { - waitingBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo + waitingBatchUIData(batchSubmitted.batchInfo.batchTime) = + BatchUIData(batchSubmitted.batchInfo) } } override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = synchronized { - runningBatchInfos(batchStarted.batchInfo.batchTime) = batchStarted.batchInfo - waitingBatchInfos.remove(batchStarted.batchInfo.batchTime) + val batchUIData = BatchUIData(batchStarted.batchInfo) + runningBatchUIData(batchStarted.batchInfo.batchTime) = BatchUIData(batchStarted.batchInfo) + waitingBatchUIData.remove(batchStarted.batchInfo.batchTime) - batchStarted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => - totalReceivedRecords += infos.map(_.numRecords).sum - } + totalReceivedRecords += batchUIData.numRecords } override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = { synchronized { - waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime) - runningBatchInfos.remove(batchCompleted.batchInfo.batchTime) - completedBatchInfos.enqueue(batchCompleted.batchInfo) - if (completedBatchInfos.size > batchInfoLimit) completedBatchInfos.dequeue() + waitingBatchUIData.remove(batchCompleted.batchInfo.batchTime) + runningBatchUIData.remove(batchCompleted.batchInfo.batchTime) + val batchUIData = BatchUIData(batchCompleted.batchInfo) + completedBatchUIData.enqueue(batchUIData) + if (completedBatchUIData.size > batchUIDataLimit) { + val removedBatch = completedBatchUIData.dequeue() + batchTimeToOutputOpIdSparkJobIdPair.remove(removedBatch.batchTime) + } totalCompletedBatches += 1L - batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => - totalProcessedRecords += infos.map(_.numRecords).sum + totalProcessedRecords += batchUIData.numRecords + } + } + + override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { + getBatchTimeAndOutputOpId(jobStart.properties).foreach { case (batchTime, outputOpId) => + var outputOpIdToSparkJobIds = batchTimeToOutputOpIdSparkJobIdPair.get(batchTime) + if (outputOpIdToSparkJobIds == null) { + outputOpIdToSparkJobIds = + new ArrayBuffer[OutputOpIdAndSparkJobId]() + with SynchronizedBuffer[OutputOpIdAndSparkJobId] + batchTimeToOutputOpIdSparkJobIdPair.put(batchTime, outputOpIdToSparkJobIds) } + outputOpIdToSparkJobIds += OutputOpIdAndSparkJobId(outputOpId, jobStart.jobId) } } - def numReceivers: Int = synchronized { - ssc.graph.getReceiverInputStreams().size + private def getBatchTimeAndOutputOpId(properties: Properties): Option[(Time, Int)] = { + val batchTime = properties.getProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY) + if (batchTime == null) { + // Not submitted from JobScheduler + None + } else { + val outputOpId = properties.getProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY) + assert(outputOpId != null) + Some(Time(batchTime.toLong) -> outputOpId.toInt) + } } + def numReceivers: Int = ssc.graph.getReceiverInputStreams().size + def numTotalCompletedBatches: Long = synchronized { totalCompletedBatches } @@ -106,19 +160,19 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } def numUnprocessedBatches: Long = synchronized { - waitingBatchInfos.size + runningBatchInfos.size + waitingBatchUIData.size + runningBatchUIData.size } - def waitingBatches: Seq[BatchInfo] = synchronized { - waitingBatchInfos.values.toSeq + def waitingBatches: Seq[BatchUIData] = synchronized { + waitingBatchUIData.values.toSeq } - def runningBatches: Seq[BatchInfo] = synchronized { - runningBatchInfos.values.toSeq + def runningBatches: Seq[BatchUIData] = synchronized { + runningBatchUIData.values.toSeq } - def retainedCompletedBatches: Seq[BatchInfo] = synchronized { - completedBatchInfos.toSeq + def retainedCompletedBatches: Seq[BatchUIData] = synchronized { + completedBatchUIData.toSeq } def processingDelayDistribution: Option[Distribution] = synchronized { @@ -134,15 +188,11 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } def receivedRecordsDistributions: Map[Int, Option[Distribution]] = synchronized { - val latestBatchInfos = retainedBatches.reverse.take(batchInfoLimit) - val latestBlockInfos = latestBatchInfos.map(_.receivedBlockInfo) + val latestBatches = retainedBatches.reverse.take(batchUIDataLimit) (0 until numReceivers).map { receiverId => - val blockInfoOfParticularReceiver = latestBlockInfos.map { batchInfo => - batchInfo.get(receiverId).getOrElse(Array.empty) - } - val recordsOfParticularReceiver = blockInfoOfParticularReceiver.map { blockInfo => - // calculate records per second for each batch - blockInfo.map(_.numRecords).sum.toDouble * 1000 / batchDuration + val recordsOfParticularReceiver = latestBatches.map { batch => + // calculate records per second for each batch + batch.receiverNumRecords.get(receiverId).sum.toDouble * 1000 / batchDuration } val distributionOption = Distribution(recordsOfParticularReceiver) (receiverId, distributionOption) @@ -150,10 +200,10 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } def lastReceivedBatchRecords: Map[Int, Long] = synchronized { - val lastReceivedBlockInfoOption = lastReceivedBatch.map(_.receivedBlockInfo) + val lastReceivedBlockInfoOption = lastReceivedBatch.map(_.receiverNumRecords) lastReceivedBlockInfoOption.map { lastReceivedBlockInfo => (0 until numReceivers).map { receiverId => - (receiverId, lastReceivedBlockInfo(receiverId).map(_.numRecords).sum) + (receiverId, lastReceivedBlockInfo.getOrElse(receiverId, 0L)) }.toMap }.getOrElse { (0 until numReceivers).map(receiverId => (receiverId, 0L)).toMap @@ -164,20 +214,39 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) receiverInfos.get(receiverId) } - def lastCompletedBatch: Option[BatchInfo] = synchronized { - completedBatchInfos.sortBy(_.batchTime)(Time.ordering).lastOption + def lastCompletedBatch: Option[BatchUIData] = synchronized { + completedBatchUIData.sortBy(_.batchTime)(Time.ordering).lastOption } - def lastReceivedBatch: Option[BatchInfo] = synchronized { + def lastReceivedBatch: Option[BatchUIData] = synchronized { retainedBatches.lastOption } - private def retainedBatches: Seq[BatchInfo] = { - (waitingBatchInfos.values.toSeq ++ - runningBatchInfos.values.toSeq ++ completedBatchInfos).sortBy(_.batchTime)(Time.ordering) + private def retainedBatches: Seq[BatchUIData] = { + (waitingBatchUIData.values.toSeq ++ + runningBatchUIData.values.toSeq ++ completedBatchUIData).sortBy(_.batchTime)(Time.ordering) + } + + private def extractDistribution(getMetric: BatchUIData => Option[Long]): Option[Distribution] = { + Distribution(completedBatchUIData.flatMap(getMetric(_)).map(_.toDouble)) } - private def extractDistribution(getMetric: BatchInfo => Option[Long]): Option[Distribution] = { - Distribution(completedBatchInfos.flatMap(getMetric(_)).map(_.toDouble)) + def getBatchUIData(batchTime: Time): Option[BatchUIData] = synchronized { + val batchUIData = waitingBatchUIData.get(batchTime).orElse { + runningBatchUIData.get(batchTime).orElse { + completedBatchUIData.find(batch => batch.batchTime == batchTime) + } + } + batchUIData.foreach { _batchUIData => + val outputOpIdToSparkJobIds = + Option(batchTimeToOutputOpIdSparkJobIdPair.get(batchTime)).getOrElse(Seq.empty) + _batchUIData.outputOpIdSparkJobIdPairs = outputOpIdToSparkJobIds + } + batchUIData } } + +private[streaming] object StreamingJobProgressListener { + type SparkJobId = Int + type OutputOpId = Int +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index 9a860ea4a6c68..e4039639adbad 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -27,14 +27,16 @@ import StreamingTab._ * Spark Web UI tab that shows statistics of a streaming job. * This assumes the given SparkContext has enabled its SparkUI. */ -private[spark] class StreamingTab(ssc: StreamingContext) +private[spark] class StreamingTab(val ssc: StreamingContext) extends SparkUITab(getSparkUI(ssc), "streaming") with Logging { val parent = getSparkUI(ssc) val listener = ssc.progressListener ssc.addStreamingListener(listener) + ssc.sc.addSparkListener(listener) attachPage(new StreamingPage(this)) + attachPage(new BatchPage(this)) parent.attachTab(this) def detach() { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala similarity index 79% rename from streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala rename to streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 38a93cc3c9a1f..9985fedc35141 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -17,6 +17,7 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer +import java.util.{Iterator => JIterator} import scala.collection.mutable.ArrayBuffer import scala.concurrent.{Await, ExecutionContext, Future} @@ -24,9 +25,9 @@ import scala.language.postfixOps import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.Logging -import org.apache.spark.util.{ThreadUtils, Clock, SystemClock} -import WriteAheadLogManager._ + +import org.apache.spark.util.ThreadUtils +import org.apache.spark.{Logging, SparkConf} /** * This class manages write ahead log files. @@ -34,37 +35,32 @@ import WriteAheadLogManager._ * - Recovers the log files and the reads the recovered records upon failures. * - Cleans up old log files. * - * Uses [[org.apache.spark.streaming.util.WriteAheadLogWriter]] to write - * and [[org.apache.spark.streaming.util.WriteAheadLogReader]] to read. + * Uses [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]] to write + * and [[org.apache.spark.streaming.util.FileBasedWriteAheadLogReader]] to read. * * @param logDirectory Directory when rotating log files will be created. * @param hadoopConf Hadoop configuration for reading/writing log files. - * @param rollingIntervalSecs The interval in seconds with which logs will be rolled over. - * Default is one minute. - * @param maxFailures Max number of failures that is tolerated for every attempt to write to log. - * Default is three. - * @param callerName Optional name of the class who is using this manager. - * @param clock Optional clock that is used to check for rotation interval. */ -private[streaming] class WriteAheadLogManager( +private[streaming] class FileBasedWriteAheadLog( + conf: SparkConf, logDirectory: String, hadoopConf: Configuration, - rollingIntervalSecs: Int = 60, - maxFailures: Int = 3, - callerName: String = "", - clock: Clock = new SystemClock - ) extends Logging { + rollingIntervalSecs: Int, + maxFailures: Int + ) extends WriteAheadLog with Logging { + + import FileBasedWriteAheadLog._ private val pastLogs = new ArrayBuffer[LogInfo] - private val callerNameTag = - if (callerName.nonEmpty) s" for $callerName" else "" + private val callerNameTag = getCallerName.map(c => s" for $c").getOrElse("") + private val threadpoolName = s"WriteAheadLogManager $callerNameTag" implicit private val executionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonSingleThreadExecutor(threadpoolName)) override protected val logName = s"WriteAheadLogManager $callerNameTag" private var currentLogPath: Option[String] = None - private var currentLogWriter: WriteAheadLogWriter = null + private var currentLogWriter: FileBasedWriteAheadLogWriter = null private var currentLogWriterStartTime: Long = -1L private var currentLogWriterStopTime: Long = -1L @@ -75,14 +71,14 @@ private[streaming] class WriteAheadLogManager( * ByteBuffer to HDFS. When this method returns, the data is guaranteed to have been flushed * to HDFS, and will be available for readers to read. */ - def writeToLog(byteBuffer: ByteBuffer): WriteAheadLogFileSegment = synchronized { - var fileSegment: WriteAheadLogFileSegment = null + def write(byteBuffer: ByteBuffer, time: Long): FileBasedWriteAheadLogSegment = synchronized { + var fileSegment: FileBasedWriteAheadLogSegment = null var failures = 0 var lastException: Exception = null var succeeded = false while (!succeeded && failures < maxFailures) { try { - fileSegment = getLogWriter(clock.getTimeMillis()).write(byteBuffer) + fileSegment = getLogWriter(time).write(byteBuffer) succeeded = true } catch { case ex: Exception => @@ -99,6 +95,19 @@ private[streaming] class WriteAheadLogManager( fileSegment } + def read(segment: WriteAheadLogRecordHandle): ByteBuffer = { + val fileSegment = segment.asInstanceOf[FileBasedWriteAheadLogSegment] + var reader: FileBasedWriteAheadLogRandomReader = null + var byteBuffer: ByteBuffer = null + try { + reader = new FileBasedWriteAheadLogRandomReader(fileSegment.path, hadoopConf) + byteBuffer = reader.read(fileSegment) + } finally { + reader.close() + } + byteBuffer + } + /** * Read all the existing logs from the log directory. * @@ -108,12 +117,14 @@ private[streaming] class WriteAheadLogManager( * the latest the records. This does not deal with currently active log files, and * hence the implementation is kept simple. */ - def readFromLog(): Iterator[ByteBuffer] = synchronized { + def readAll(): JIterator[ByteBuffer] = synchronized { + import scala.collection.JavaConversions._ val logFilesToRead = pastLogs.map{ _.path} ++ currentLogPath logInfo("Reading from the logs: " + logFilesToRead.mkString("\n")) + logFilesToRead.iterator.map { file => logDebug(s"Creating log reader with $file") - new WriteAheadLogReader(file, hadoopConf) + new FileBasedWriteAheadLogReader(file, hadoopConf) } flatMap { x => x } } @@ -129,7 +140,7 @@ private[streaming] class WriteAheadLogManager( * deleted. This should be set to true only for testing. Else the files will be deleted * asynchronously. */ - def cleanupOldLogs(threshTime: Long, waitForCompletion: Boolean): Unit = { + def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { val oldLogFiles = synchronized { pastLogs.filter { _.endTime < threshTime } } logInfo(s"Attempting to clear ${oldLogFiles.size} old log files in $logDirectory " + s"older than $threshTime: ${oldLogFiles.map { _.path }.mkString("\n")}") @@ -160,7 +171,7 @@ private[streaming] class WriteAheadLogManager( /** Stop the manager, close any open log writer */ - def stop(): Unit = synchronized { + def close(): Unit = synchronized { if (currentLogWriter != null) { currentLogWriter.close() } @@ -169,7 +180,7 @@ private[streaming] class WriteAheadLogManager( } /** Get the current log writer while taking care of rotation */ - private def getLogWriter(currentTime: Long): WriteAheadLogWriter = synchronized { + private def getLogWriter(currentTime: Long): FileBasedWriteAheadLogWriter = synchronized { if (currentLogWriter == null || currentTime > currentLogWriterStopTime) { resetWriter() currentLogPath.foreach { @@ -180,7 +191,7 @@ private[streaming] class WriteAheadLogManager( val newLogPath = new Path(logDirectory, timeToLogFile(currentLogWriterStartTime, currentLogWriterStopTime)) currentLogPath = Some(newLogPath.toString) - currentLogWriter = new WriteAheadLogWriter(currentLogPath.get, hadoopConf) + currentLogWriter = new FileBasedWriteAheadLogWriter(currentLogPath.get, hadoopConf) } currentLogWriter } @@ -207,7 +218,7 @@ private[streaming] class WriteAheadLogManager( } } -private[util] object WriteAheadLogManager { +private[streaming] object FileBasedWriteAheadLog { case class LogInfo(startTime: Long, endTime: Long, path: String) @@ -217,6 +228,11 @@ private[util] object WriteAheadLogManager { s"log-$startTime-$stopTime" } + def getCallerName(): Option[String] = { + val stackTraceClasses = Thread.currentThread.getStackTrace().map(_.getClassName) + stackTraceClasses.find(!_.contains("WriteAheadLog")).flatMap(_.split(".").lastOption) + } + /** Convert a sequence of files to a sequence of sorted LogInfo objects */ def logFilesTologInfo(files: Seq[Path]): Seq[LogInfo] = { files.flatMap { file => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala similarity index 83% rename from streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala rename to streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala index 003989092a42a..f7168229ec15a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala @@ -23,16 +23,16 @@ import org.apache.hadoop.conf.Configuration /** * A random access reader for reading write ahead log files written using - * [[org.apache.spark.streaming.util.WriteAheadLogWriter]]. Given the file segment info, - * this reads the record (bytebuffer) from the log file. + * [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]]. Given the file segment info, + * this reads the record (ByteBuffer) from the log file. */ -private[streaming] class WriteAheadLogRandomReader(path: String, conf: Configuration) +private[streaming] class FileBasedWriteAheadLogRandomReader(path: String, conf: Configuration) extends Closeable { private val instream = HdfsUtils.getInputStream(path, conf) private var closed = false - def read(segment: WriteAheadLogFileSegment): ByteBuffer = synchronized { + def read(segment: FileBasedWriteAheadLogSegment): ByteBuffer = synchronized { assertOpen() instream.seek(segment.offset) val nextLength = instream.readInt() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala similarity index 93% rename from streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogReader.scala rename to streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala index 2afc0d1551acf..c3bb59f3fef94 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogReader.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala @@ -24,11 +24,11 @@ import org.apache.spark.Logging /** * A reader for reading write ahead log files written using - * [[org.apache.spark.streaming.util.WriteAheadLogWriter]]. This reads + * [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]]. This reads * the records (bytebuffers) in the log file sequentially and return them as an * iterator of bytebuffers. */ -private[streaming] class WriteAheadLogReader(path: String, conf: Configuration) +private[streaming] class FileBasedWriteAheadLogReader(path: String, conf: Configuration) extends Iterator[ByteBuffer] with Closeable with Logging { private val instream = HdfsUtils.getInputStream(path, conf) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogFileSegment.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogSegment.scala similarity index 86% rename from streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogFileSegment.scala rename to streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogSegment.scala index 1005a2c8ec303..2e1f1528fad20 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogFileSegment.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogSegment.scala @@ -17,4 +17,5 @@ package org.apache.spark.streaming.util /** Class for representing a segment of data in a write ahead log file */ -private[streaming] case class WriteAheadLogFileSegment (path: String, offset: Long, length: Int) +private[streaming] case class FileBasedWriteAheadLogSegment(path: String, offset: Long, length: Int) + extends WriteAheadLogRecordHandle diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala similarity index 88% rename from streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogWriter.scala rename to streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala index 679f6a6dfd7c1..e146bec32a456 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogWriter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala @@ -17,18 +17,17 @@ package org.apache.spark.streaming.util import java.io._ -import java.net.URI import java.nio.ByteBuffer import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FSDataOutputStream, FileSystem} +import org.apache.hadoop.fs.FSDataOutputStream /** * A writer for writing byte-buffers to a write ahead log file. */ -private[streaming] class WriteAheadLogWriter(path: String, hadoopConf: Configuration) +private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: Configuration) extends Closeable { private lazy val stream = HdfsUtils.getOutputStream(path, hadoopConf) @@ -43,11 +42,11 @@ private[streaming] class WriteAheadLogWriter(path: String, hadoopConf: Configura private var closed = false /** Write the bytebuffer to the log file */ - def write(data: ByteBuffer): WriteAheadLogFileSegment = synchronized { + def write(data: ByteBuffer): FileBasedWriteAheadLogSegment = synchronized { assertOpen() data.rewind() // Rewind to ensure all data in the buffer is retrieved val lengthToWrite = data.remaining() - val segment = new WriteAheadLogFileSegment(path, nextOffset, lengthToWrite) + val segment = new FileBasedWriteAheadLogSegment(path, nextOffset, lengthToWrite) stream.writeInt(lengthToWrite) if (data.hasArray) { stream.write(data.array()) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala new file mode 100644 index 0000000000000..7f6ff12c58d47 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf, SparkException} + +/** A helper class with utility functions related to the WriteAheadLog interface */ +private[streaming] object WriteAheadLogUtils extends Logging { + val RECEIVER_WAL_ENABLE_CONF_KEY = "spark.streaming.receiver.writeAheadLog.enable" + val RECEIVER_WAL_CLASS_CONF_KEY = "spark.streaming.receiver.writeAheadLog.class" + val RECEIVER_WAL_ROLLING_INTERVAL_CONF_KEY = + "spark.streaming.receiver.writeAheadLog.rollingIntervalSecs" + val RECEIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.receiver.writeAheadLog.maxFailures" + + val DRIVER_WAL_CLASS_CONF_KEY = "spark.streaming.driver.writeAheadLog.class" + val DRIVER_WAL_ROLLING_INTERVAL_CONF_KEY = + "spark.streaming.driver.writeAheadLog.rollingIntervalSecs" + val DRIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.driver.writeAheadLog.maxFailures" + + val DEFAULT_ROLLING_INTERVAL_SECS = 60 + val DEFAULT_MAX_FAILURES = 3 + + def enableReceiverLog(conf: SparkConf): Boolean = { + conf.getBoolean(RECEIVER_WAL_ENABLE_CONF_KEY, false) + } + + def getRollingIntervalSecs(conf: SparkConf, isDriver: Boolean): Int = { + if (isDriver) { + conf.getInt(DRIVER_WAL_ROLLING_INTERVAL_CONF_KEY, DEFAULT_ROLLING_INTERVAL_SECS) + } else { + conf.getInt(RECEIVER_WAL_ROLLING_INTERVAL_CONF_KEY, DEFAULT_ROLLING_INTERVAL_SECS) + } + } + + def getMaxFailures(conf: SparkConf, isDriver: Boolean): Int = { + if (isDriver) { + conf.getInt(DRIVER_WAL_MAX_FAILURES_CONF_KEY, DEFAULT_MAX_FAILURES) + } else { + conf.getInt(RECEIVER_WAL_MAX_FAILURES_CONF_KEY, DEFAULT_MAX_FAILURES) + } + } + + /** + * Create a WriteAheadLog for the driver. If configured with custom WAL class, it will try + * to create instance of that class, otherwise it will create the default FileBasedWriteAheadLog. + */ + def createLogForDriver( + sparkConf: SparkConf, + fileWalLogDirectory: String, + fileWalHadoopConf: Configuration + ): WriteAheadLog = { + createLog(true, sparkConf, fileWalLogDirectory, fileWalHadoopConf) + } + + /** + * Create a WriteAheadLog for the receiver. If configured with custom WAL class, it will try + * to create instance of that class, otherwise it will create the default FileBasedWriteAheadLog. + */ + def createLogForReceiver( + sparkConf: SparkConf, + fileWalLogDirectory: String, + fileWalHadoopConf: Configuration + ): WriteAheadLog = { + createLog(false, sparkConf, fileWalLogDirectory, fileWalHadoopConf) + } + + /** + * Create a WriteAheadLog based on the value of the given config key. The config key is used + * to get the class name from the SparkConf. If the class is configured, it will try to + * create instance of that class by first trying `new CustomWAL(sparkConf, logDir)` then trying + * `new CustomWAL(sparkConf)`. If either fails, it will fail. If no class is configured, then + * it will create the default FileBasedWriteAheadLog. + */ + private def createLog( + isDriver: Boolean, + sparkConf: SparkConf, + fileWalLogDirectory: String, + fileWalHadoopConf: Configuration + ): WriteAheadLog = { + + val classNameOption = if (isDriver) { + sparkConf.getOption(DRIVER_WAL_CLASS_CONF_KEY) + } else { + sparkConf.getOption(RECEIVER_WAL_CLASS_CONF_KEY) + } + classNameOption.map { className => + try { + instantiateClass( + Utils.classForName(className).asInstanceOf[Class[_ <: WriteAheadLog]], sparkConf) + } catch { + case NonFatal(e) => + throw new SparkException(s"Could not create a write ahead log of class $className", e) + } + }.getOrElse { + new FileBasedWriteAheadLog(sparkConf, fileWalLogDirectory, fileWalHadoopConf, + getRollingIntervalSecs(sparkConf, isDriver), getMaxFailures(sparkConf, isDriver)) + } + } + + /** Instantiate the class, either using single arg constructor or zero arg constructor */ + private def instantiateClass(cls: Class[_ <: WriteAheadLog], conf: SparkConf): WriteAheadLog = { + try { + cls.getConstructor(classOf[SparkConf]).newInstance(conf) + } catch { + case nsme: NoSuchMethodException => + cls.getConstructor().newInstance() + } + } +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 90340753a4eed..b1adf881dd0f5 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -21,11 +21,13 @@ import java.lang.Iterable; import java.nio.charset.Charset; import java.util.*; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; + import scala.Tuple2; import org.junit.Assert; @@ -45,6 +47,7 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; import org.apache.spark.util.Utils; +import org.apache.spark.SparkConf; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -929,7 +932,7 @@ public void testPairMap() { // Maps pair -> pair of different type public Tuple2 call(Tuple2 in) throws Exception { return in.swap(); } - }); + }); JavaTestUtils.attachTestOutputStream(reversed); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -987,12 +990,12 @@ public void testPairMap2() { // Maps pair -> single JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaDStream reversed = pairStream.map( - new Function, Integer>() { - @Override - public Integer call(Tuple2 in) throws Exception { - return in._2(); - } - }); + new Function, Integer>() { + @Override + public Integer call(Tuple2 in) throws Exception { + return in._2(); + } + }); JavaTestUtils.attachTestOutputStream(reversed); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1123,7 +1126,7 @@ public void testCombineByKey() { JavaPairDStream combined = pairStream.combineByKey( new Function() { - @Override + @Override public Integer call(Integer i) throws Exception { return i; } @@ -1144,14 +1147,14 @@ public void testCountByValue() { Arrays.asList("hello")); List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), - Arrays.asList( - new Tuple2("hello", 1L))); + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("world", 1L)), + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("moon", 1L)), + Arrays.asList( + new Tuple2("hello", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream counted = stream.countByValue(); @@ -1249,17 +1252,17 @@ public void testUpdateStateByKey() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v: values) { - out = out + v; + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v : values) { + out = out + v; + } + return Optional.of(out); } - return Optional.of(out); - } }); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1292,17 +1295,17 @@ public void testUpdateStateByKeyWithInitial() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v: values) { - out = out + v; + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v : values) { + out = out + v; + } + return Optional.of(out); } - return Optional.of(out); - } }, new HashPartitioner(1), initialRDD); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1328,7 +1331,7 @@ public void testReduceByKeyAndWindowWithInverse() { JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1707,6 +1710,74 @@ public Integer call(String s) throws Exception { Utils.deleteRecursively(tempDir); } + @SuppressWarnings("unchecked") + @Test + public void testContextGetOrCreate() throws InterruptedException { + + final SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("newContext", "true"); + + File emptyDir = Files.createTempDir(); + emptyDir.deleteOnExit(); + StreamingContextSuite contextSuite = new StreamingContextSuite(); + String corruptedCheckpointDir = contextSuite.createCorruptedCheckpoint(); + String checkpointDir = contextSuite.createValidCheckpoint(); + + // Function to create JavaStreamingContext without any output operations + // (used to detect the new context) + final AtomicBoolean newContextCreated = new AtomicBoolean(false); + Function0 creatingFunc = new Function0() { + public JavaStreamingContext call() { + newContextCreated.set(true); + return new JavaStreamingContext(conf, Seconds.apply(1)); + } + }; + + newContextCreated.set(false); + ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc); + Assert.assertTrue("new context not created", newContextCreated.get()); + ssc.stop(); + + newContextCreated.set(false); + ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc, + new org.apache.hadoop.conf.Configuration(), true); + Assert.assertTrue("new context not created", newContextCreated.get()); + ssc.stop(); + + newContextCreated.set(false); + ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, + new org.apache.hadoop.conf.Configuration()); + Assert.assertTrue("old context not recovered", !newContextCreated.get()); + ssc.stop(); + + // Function to create JavaStreamingContext using existing JavaSparkContext + // without any output operations (used to detect the new context) + Function creatingFunc2 = + new Function() { + public JavaStreamingContext call(JavaSparkContext context) { + newContextCreated.set(true); + return new JavaStreamingContext(context, Seconds.apply(1)); + } + }; + + JavaSparkContext sc = new JavaSparkContext(conf); + newContextCreated.set(false); + ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc2, sc); + Assert.assertTrue("new context not created", newContextCreated.get()); + ssc.stop(false); + + newContextCreated.set(false); + ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc2, sc, true); + Assert.assertTrue("new context not created", newContextCreated.get()); + ssc.stop(false); + + newContextCreated.set(false); + ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc2, sc); + Assert.assertTrue("old context not recovered", !newContextCreated.get()); + ssc.stop(); + } /* TEST DISABLED: Pending a discussion about checkpoint() semantics with TD @SuppressWarnings("unchecked") diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java new file mode 100644 index 0000000000000..50e8f9fc159c8 --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import java.util.ArrayList; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collection; + +import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.collections.Transformer; +import org.apache.spark.SparkConf; +import org.apache.spark.streaming.util.WriteAheadLog; +import org.apache.spark.streaming.util.WriteAheadLogRecordHandle; +import org.apache.spark.streaming.util.WriteAheadLogUtils; + +import org.junit.Test; +import org.junit.Assert; + +class JavaWriteAheadLogSuiteHandle extends WriteAheadLogRecordHandle { + int index = -1; + public JavaWriteAheadLogSuiteHandle(int idx) { + index = idx; + } +} + +public class JavaWriteAheadLogSuite extends WriteAheadLog { + + class Record { + long time; + int index; + ByteBuffer buffer; + + public Record(long tym, int idx, ByteBuffer buf) { + index = idx; + time = tym; + buffer = buf; + } + } + private int index = -1; + private ArrayList records = new ArrayList(); + + + // Methods for WriteAheadLog + @Override + public WriteAheadLogRecordHandle write(java.nio.ByteBuffer record, long time) { + index += 1; + records.add(new org.apache.spark.streaming.JavaWriteAheadLogSuite.Record(time, index, record)); + return new JavaWriteAheadLogSuiteHandle(index); + } + + @Override + public java.nio.ByteBuffer read(WriteAheadLogRecordHandle handle) { + if (handle instanceof JavaWriteAheadLogSuiteHandle) { + int reqdIndex = ((JavaWriteAheadLogSuiteHandle) handle).index; + for (Record record: records) { + if (record.index == reqdIndex) { + return record.buffer; + } + } + } + return null; + } + + @Override + public java.util.Iterator readAll() { + Collection buffers = CollectionUtils.collect(records, new Transformer() { + @Override + public Object transform(Object input) { + return ((Record) input).buffer; + } + }); + return buffers.iterator(); + } + + @Override + public void clean(long threshTime, boolean waitForCompletion) { + for (int i = 0; i < records.size(); i++) { + if (records.get(i).time < threshTime) { + records.remove(i); + i--; + } + } + } + + @Override + public void close() { + records.clear(); + } + + @Test + public void testCustomWAL() { + SparkConf conf = new SparkConf(); + conf.set("spark.streaming.driver.writeAheadLog.class", JavaWriteAheadLogSuite.class.getName()); + WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null); + + String data1 = "data1"; + WriteAheadLogRecordHandle handle = wal.write(ByteBuffer.wrap(data1.getBytes()), 1234); + Assert.assertTrue(handle instanceof JavaWriteAheadLogSuiteHandle); + Assert.assertTrue(new String(wal.read(handle).array()).equals(data1)); + + wal.write(ByteBuffer.wrap("data2".getBytes()), 1235); + wal.write(ByteBuffer.wrap("data3".getBytes()), 1236); + wal.write(ByteBuffer.wrap("data4".getBytes()), 1237); + wal.clean(1236, false); + + java.util.Iterator dataIterator = wal.readAll(); + ArrayList readData = new ArrayList(); + while (dataIterator.hasNext()) { + readData.add(new String(dataIterator.next().array())); + } + Assert.assertTrue(readData.equals(Arrays.asList("data3", "data4"))); + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 54c30440a6e8d..6b0a3f91d4d06 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -430,9 +430,8 @@ class CheckpointSuite extends TestSuiteBase { assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3) } // Wait for a checkpoint to be written - val fs = new Path(checkpointDir).getFileSystem(ssc.sc.hadoopConfiguration) eventually(eventuallyTimeout) { - assert(Checkpoint.getCheckpointFiles(checkpointDir, fs).size === 6) + assert(Checkpoint.getCheckpointFiles(checkpointDir).size === 6) } ssc.stop() // Check that we shut down while the third batch was being processed diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index c090eaec2928d..23804237bda80 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -43,7 +43,7 @@ import WriteAheadLogSuite._ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { - val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") val hadoopConf = new Configuration() val storageLevel = StorageLevel.MEMORY_ONLY_SER val streamId = 1 @@ -130,10 +130,13 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche "Unexpected store result type" ) // Verify the data in write ahead log files is correct - val fileSegments = storeResults.map { _.asInstanceOf[WriteAheadLogBasedStoreResult].segment} - val loggedData = fileSegments.flatMap { segment => - val reader = new WriteAheadLogRandomReader(segment.path, hadoopConf) - val bytes = reader.read(segment) + val walSegments = storeResults.map { result => + result.asInstanceOf[WriteAheadLogBasedStoreResult].walRecordHandle + } + val loggedData = walSegments.flatMap { walSegment => + val fileSegment = walSegment.asInstanceOf[FileBasedWriteAheadLogSegment] + val reader = new FileBasedWriteAheadLogRandomReader(fileSegment.path, hadoopConf) + val bytes = reader.read(fileSegment) reader.close() blockManager.dataDeserialize(generateBlockId(), bytes).toList } @@ -148,13 +151,13 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche } } - test("WriteAheadLogBasedBlockHandler - cleanup old blocks") { + test("WriteAheadLogBasedBlockHandler - clean old blocks") { withWriteAheadLogBasedBlockHandler { handler => val blocks = Seq.tabulate(10) { i => IteratorBlock(Iterator(1 to i)) } storeBlocks(handler, blocks) val preCleanupLogFiles = getWriteAheadLogFiles() - preCleanupLogFiles.size should be > 1 + require(preCleanupLogFiles.size > 1) // this depends on the number of blocks inserted using generateAndStoreData() manualClock.getTimeMillis() shouldEqual 5000L @@ -218,6 +221,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche /** Instantiate a WriteAheadLogBasedBlockHandler and run a code with it */ private def withWriteAheadLogBasedBlockHandler(body: WriteAheadLogBasedBlockHandler => Unit) { + require(WriteAheadLogUtils.getRollingIntervalSecs(conf, isDriver = false) === 1) val receivedBlockHandler = new WriteAheadLogBasedBlockHandler(blockManager, 1, storageLevel, conf, hadoopConf, tempDirectory.toString, manualClock) try { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index b63b37d9f9cef..8317fb9720416 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.util.WriteAheadLogReader +import org.apache.spark.streaming.util.{WriteAheadLogUtils, FileBasedWriteAheadLogReader} import org.apache.spark.streaming.util.WriteAheadLogSuite._ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} @@ -59,7 +59,7 @@ class ReceivedBlockTrackerSuite test("block addition, and block to batch allocation") { val receivedBlockTracker = createTracker(setCheckpointDir = false) - receivedBlockTracker.isLogManagerEnabled should be (false) // should be disable by default + receivedBlockTracker.isWriteAheadLogEnabled should be (false) // should be disable by default receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual Seq.empty val blockInfos = generateBlockInfos() @@ -88,7 +88,7 @@ class ReceivedBlockTrackerSuite receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos } - test("block addition, block to batch allocation and cleanup with write ahead log") { + test("block addition, block to batch allocation and clean up with write ahead log") { val manualClock = new ManualClock // Set the time increment level to twice the rotation interval so that every increment creates // a new log file @@ -113,11 +113,15 @@ class ReceivedBlockTrackerSuite logInfo(s"\n\n=====================\n$message\n$fileContents\n=====================\n") } - // Start tracker and add blocks + // Set WAL configuration conf.set("spark.streaming.receiver.writeAheadLog.enable", "true") - conf.set("spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", "1") + conf.set("spark.streaming.driver.writeAheadLog.rollingIntervalSecs", "1") + require(WriteAheadLogUtils.enableReceiverLog(conf)) + require(WriteAheadLogUtils.getRollingIntervalSecs(conf, isDriver = true) === 1) + + // Start tracker and add blocks val tracker1 = createTracker(clock = manualClock) - tracker1.isLogManagerEnabled should be (true) + tracker1.isWriteAheadLogEnabled should be (true) val blockInfos1 = addBlockInfos(tracker1) tracker1.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 @@ -171,7 +175,7 @@ class ReceivedBlockTrackerSuite eventually(timeout(10 seconds), interval(10 millisecond)) { getWriteAheadLogFiles() should not contain oldestLogFile } - printLogFiles("After cleanup") + printLogFiles("After clean") // Restart tracker and verify recovered state, specifically whether info about the first // batch has been removed, but not the second batch @@ -192,17 +196,17 @@ class ReceivedBlockTrackerSuite test("setting checkpoint dir but not enabling write ahead log") { // When WAL config is not set, log manager should not be enabled val tracker1 = createTracker(setCheckpointDir = true) - tracker1.isLogManagerEnabled should be (false) + tracker1.isWriteAheadLogEnabled should be (false) // When WAL is explicitly disabled, log manager should not be enabled conf.set("spark.streaming.receiver.writeAheadLog.enable", "false") val tracker2 = createTracker(setCheckpointDir = true) - tracker2.isLogManagerEnabled should be(false) + tracker2.isWriteAheadLogEnabled should be(false) } /** * Create tracker object with the optional provided clock. Use fake clock if you - * want to control time by manually incrementing it to test log cleanup. + * want to control time by manually incrementing it to test log clean. */ def createTracker( setCheckpointDir: Boolean = true, @@ -231,7 +235,7 @@ class ReceivedBlockTrackerSuite def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles) : Seq[ReceivedBlockTrackerLogEvent] = { logFiles.flatMap { - file => new WriteAheadLogReader(file, hadoopConf).toSeq + file => new FileBasedWriteAheadLogReader(file, hadoopConf).toSeq }.map { byteBuffer => Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) }.toList @@ -250,7 +254,7 @@ class ReceivedBlockTrackerSuite BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos)))) } - /** Create batch cleanup object from the given info */ + /** Create batch clean object from the given info */ def createBatchCleanup(time: Long, moreTimes: Long*): BatchCleanupEvent = { BatchCleanupEvent((Seq(time) ++ moreTimes).map(Time.apply)) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index b84129fd70dd4..393a360cfe150 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -225,7 +225,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { .setAppName(framework) .set("spark.ui.enabled", "true") .set("spark.streaming.receiver.writeAheadLog.enable", "true") - .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + .set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") val batchDuration = Milliseconds(500) val tempDirectory = Utils.createTempDir() val logDirectory1 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 0)) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 58353a5f97c8a..5207b7109e69b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -17,8 +17,10 @@ package org.apache.spark.streaming +import java.io.File import java.util.concurrent.atomic.AtomicInteger +import org.apache.commons.io.FileUtils import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ @@ -330,6 +332,139 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } } + test("getOrCreate") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + + // Function to create StreamingContext that has a config to identify it to be new context + var newContextCreated = false + def creatingFunction(): StreamingContext = { + newContextCreated = true + new StreamingContext(conf, batchDuration) + } + + // Call ssc.stop after a body of code + def testGetOrCreate(body: => Unit): Unit = { + newContextCreated = false + try { + body + } finally { + if (ssc != null) { + ssc.stop() + } + ssc = null + } + } + + val emptyPath = Utils.createTempDir().getAbsolutePath() + + // getOrCreate should create new context with empty path + testGetOrCreate { + ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + } + + val corrutedCheckpointPath = createCorruptedCheckpoint() + + // getOrCreate should throw exception with fake checkpoint file and createOnError = false + intercept[Exception] { + ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _) + } + + // getOrCreate should throw exception with fake checkpoint file + intercept[Exception] { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, createOnError = false) + } + + // getOrCreate should create new context with fake checkpoint file and createOnError = true + testGetOrCreate { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + } + + val checkpointPath = createValidCheckpoint() + + // getOrCreate should recover context with checkpoint path, and recover old configuration + testGetOrCreate { + ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _) + assert(ssc != null, "no context created") + assert(!newContextCreated, "old context not recovered") + assert(ssc.conf.get("someKey") === "someValue") + } + } + + test("getOrCreate with existing SparkContext") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + sc = new SparkContext(conf) + + // Function to create StreamingContext that has a config to identify it to be new context + var newContextCreated = false + def creatingFunction(sparkContext: SparkContext): StreamingContext = { + newContextCreated = true + new StreamingContext(sparkContext, batchDuration) + } + + // Call ssc.stop(stopSparkContext = false) after a body of cody + def testGetOrCreate(body: => Unit): Unit = { + newContextCreated = false + try { + body + } finally { + if (ssc != null) { + ssc.stop(stopSparkContext = false) + } + ssc = null + } + } + + val emptyPath = Utils.createTempDir().getAbsolutePath() + + // getOrCreate should create new context with empty path + testGetOrCreate { + ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _, sc, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") + } + + val corrutedCheckpointPath = createCorruptedCheckpoint() + + // getOrCreate should throw exception with fake checkpoint file and createOnError = false + intercept[Exception] { + ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _, sc) + } + + // getOrCreate should throw exception with fake checkpoint file + intercept[Exception] { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, sc, createOnError = false) + } + + // getOrCreate should create new context with fake checkpoint file and createOnError = true + testGetOrCreate { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, sc, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") + } + + val checkpointPath = createValidCheckpoint() + + // StreamingContext.getOrCreate should recover context with checkpoint path + testGetOrCreate { + ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _, sc) + assert(ssc != null, "no context created") + assert(!newContextCreated, "old context not recovered") + assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") + assert(!ssc.conf.contains("someKey"), + "recovered StreamingContext unexpectedly has old config") + } + } + test("DStream and generated RDD creation sites") { testPackage.test() } @@ -339,6 +474,30 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w val inputStream = new TestInputStream(s, input, 1) inputStream } + + def createValidCheckpoint(): String = { + val testDirectory = Utils.createTempDir().getAbsolutePath() + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + val conf = new SparkConf().setMaster(master).setAppName(appName) + conf.set("someKey", "someValue") + ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint(checkpointDirectory) + ssc.textFileStream(testDirectory).foreachRDD { rdd => rdd.count() } + ssc.start() + eventually(timeout(10000 millis)) { + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1) + } + ssc.stop() + checkpointDirectory + } + + def createCorruptedCheckpoint(): String = { + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + val fakeCheckpointFile = Checkpoint.checkpointFile(checkpointDirectory, Time(1000)) + FileUtils.write(new File(fakeCheckpointFile.toString()), "blablabla") + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).nonEmpty) + checkpointDirectory + } } class TestException(msg: String) extends Exception(msg) @@ -363,7 +522,7 @@ class TestReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging } def onStop() { - // no cleanup to be done, the receiving thread should stop on it own + // no clean to be done, the receiving thread should stop on it own } } @@ -396,7 +555,7 @@ class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) def onStop() { // Simulate slow receiver by waiting for all records to be produced while(!SlowTestReceiver.receivedAllRecords) Thread.sleep(100) - // no cleanup to be done, the receiving thread should stop on it own + // no clean to be done, the receiving thread should stop on it own } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 205ddf6dbe9b0..8de43baabc21d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming +import scala.collection.mutable.Queue + import org.openqa.selenium.WebDriver import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest._ @@ -60,8 +62,28 @@ class UISeleniumSuite ssc } + private def setupStreams(ssc: StreamingContext): Unit = { + val rdds = Queue(ssc.sc.parallelize(1 to 4, 4)) + val inputStream = ssc.queueStream(rdds) + inputStream.foreachRDD { rdd => + rdd.foreach(_ => {}) + rdd.foreach(_ => {}) + } + inputStream.foreachRDD { rdd => + rdd.foreach(_ => {}) + try { + rdd.foreach(_ => throw new RuntimeException("Oops")) + } catch { + case e: SparkException if e.getMessage.contains("Oops") => + } + } + } + test("attaching and detaching a Streaming tab") { withStreamingContext(newSparkStreamingContext()) { ssc => + setupStreams(ssc) + ssc.start() + val sparkUI = ssc.sparkContext.ui.get eventually(timeout(10 seconds), interval(50 milliseconds)) { @@ -77,8 +99,8 @@ class UISeleniumSuite statisticText should contain("Batch interval:") val h4Text = findAll(cssSelector("h4")).map(_.text).toSeq - h4Text should contain("Active Batches (0)") - h4Text should contain("Completed Batches (last 0 out of 0)") + h4Text.exists(_.matches("Active Batches \\(\\d+\\)")) should be (true) + h4Text.exists(_.matches("Completed Batches \\(last \\d+ out of \\d+\\)")) should be (true) findAll(cssSelector("""#active-batches-table th""")).map(_.text).toSeq should be { List("Batch Time", "Input Size", "Scheduling Delay", "Processing Time", "Status") @@ -86,6 +108,63 @@ class UISeleniumSuite findAll(cssSelector("""#completed-batches-table th""")).map(_.text).toSeq should be { List("Batch Time", "Input Size", "Scheduling Delay", "Processing Time", "Total Delay") } + + val batchLinks = + findAll(cssSelector("""#completed-batches-table a""")).flatMap(_.attribute("href")).toSeq + batchLinks.size should be >= 1 + + // Check a normal batch page + go to (batchLinks.last) // Last should be the first batch, so it will have some jobs + val summaryText = findAll(cssSelector("li strong")).map(_.text).toSeq + summaryText should contain ("Batch Duration:") + summaryText should contain ("Input data size:") + summaryText should contain ("Scheduling delay:") + summaryText should contain ("Processing time:") + summaryText should contain ("Total delay:") + + findAll(cssSelector("""#batch-job-table th""")).map(_.text).toSeq should be { + List("Output Op Id", "Description", "Duration", "Job Id", "Duration", + "Stages: Succeeded/Total", "Tasks (for all stages): Succeeded/Total", "Error") + } + + // Check we have 2 output op ids + val outputOpIds = findAll(cssSelector(".output-op-id-cell")).toSeq + outputOpIds.map(_.attribute("rowspan")) should be (List(Some("2"), Some("2"))) + outputOpIds.map(_.text) should be (List("0", "1")) + + // Check job ids + val jobIdCells = findAll(cssSelector( """#batch-job-table a""")).toSeq + jobIdCells.map(_.text) should be (List("0", "1", "2", "3")) + + val jobLinks = jobIdCells.flatMap(_.attribute("href")) + jobLinks.size should be (4) + + // Check stage progress + findAll(cssSelector(""".stage-progress-cell""")).map(_.text).toSeq should be + (List("1/1", "1/1", "1/1", "0/1 (1 failed)")) + + // Check job progress + findAll(cssSelector(""".progress-cell""")).map(_.text).toSeq should be + (List("1/1", "1/1", "1/1", "0/1 (1 failed)")) + + // Check stacktrace + val errorCells = findAll(cssSelector(""".stacktrace-details""")).map(_.text).toSeq + errorCells should have size 1 + errorCells(0) should include("java.lang.RuntimeException: Oops") + + // Check the job link in the batch page is right + go to (jobLinks(0)) + val jobDetails = findAll(cssSelector("li strong")).map(_.text).toSeq + jobDetails should contain("Status:") + jobDetails should contain("Completed Stages:") + + // Check a batch page without id + go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming/batch/") + webDriver.getPageSource should include ("Missing id parameter") + + // Check a non-exist batch + go to (sparkUI.appUIAddress.stripSuffix("/") + "/streaming/batch/?id=12345") + webDriver.getPageSource should include ("does not exist") } ssc.stop(false) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index c3602a5b73732..8b300d8dd3fbe 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -21,12 +21,12 @@ import java.io.File import scala.util.Random import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} -import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogWriter} +import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter} import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkContext} class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { @@ -100,9 +100,10 @@ class WriteAheadLogBackedBlockRDDSuite blockManager.putIterator(blockId, block.iterator, StorageLevel.MEMORY_ONLY_SER) } - // Generate write ahead log segments - val segments = generateFakeSegments(numPartitionsInBM) ++ - writeLogSegments(data.takeRight(numPartitionsInWAL), blockIds.takeRight(numPartitionsInWAL)) + // Generate write ahead log file segments + val recordHandles = generateFakeRecordHandles(numPartitionsInBM) ++ + generateWALRecordHandles(data.takeRight(numPartitionsInWAL), + blockIds.takeRight(numPartitionsInWAL)) // Make sure that the left `numPartitionsInBM` blocks are in block manager, and others are not require( @@ -116,24 +117,24 @@ class WriteAheadLogBackedBlockRDDSuite // Make sure that the right `numPartitionsInWAL` blocks are in WALs, and other are not require( - segments.takeRight(numPartitionsInWAL).forall(s => + recordHandles.takeRight(numPartitionsInWAL).forall(s => new File(s.path.stripPrefix("file://")).exists()), "Expected blocks not in write ahead log" ) require( - segments.take(numPartitionsInBM).forall(s => + recordHandles.take(numPartitionsInBM).forall(s => !new File(s.path.stripPrefix("file://")).exists()), "Unexpected blocks in write ahead log" ) // Create the RDD and verify whether the returned data is correct val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, - segments.toArray, storeInBlockManager = false, StorageLevel.MEMORY_ONLY) + recordHandles.toArray, storeInBlockManager = false, StorageLevel.MEMORY_ONLY) assert(rdd.collect() === data.flatten) if (testStoreInBM) { val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, - segments.toArray, storeInBlockManager = true, StorageLevel.MEMORY_ONLY) + recordHandles.toArray, storeInBlockManager = true, StorageLevel.MEMORY_ONLY) assert(rdd2.collect() === data.flatten) assert( blockIds.forall(blockManager.get(_).nonEmpty), @@ -142,12 +143,12 @@ class WriteAheadLogBackedBlockRDDSuite } } - private def writeLogSegments( + private def generateWALRecordHandles( blockData: Seq[Seq[String]], blockIds: Seq[BlockId] - ): Seq[WriteAheadLogFileSegment] = { + ): Seq[FileBasedWriteAheadLogSegment] = { require(blockData.size === blockIds.size) - val writer = new WriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf) + val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf) val segments = blockData.zip(blockIds).map { case (data, id) => writer.write(blockManager.dataSerialize(id, data.iterator)) } @@ -155,7 +156,7 @@ class WriteAheadLogBackedBlockRDDSuite segments } - private def generateFakeSegments(count: Int): Seq[WriteAheadLogFileSegment] = { - Array.fill(count)(new WriteAheadLogFileSegment("random", 0L, 0)) + private def generateFakeRecordHandles(count: Int): Seq[FileBasedWriteAheadLogSegment] = { + Array.fill(count)(new FileBasedWriteAheadLogSegment("random", 0L, 0)) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 94b1985116feb..fa89536de4054 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.streaming.ui +import java.util.Properties + import org.scalatest.Matchers +import org.apache.spark.scheduler.SparkListenerJobStart import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.{Duration, Time, Milliseconds, TestSuiteBase} @@ -28,6 +31,17 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) + private def createJobStart( + batchTime: Time, outputOpId: Int, jobId: Int): SparkListenerJobStart = { + val properties = new Properties() + properties.setProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, batchTime.milliseconds.toString) + properties.setProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, outputOpId.toString) + SparkListenerJobStart(jobId = jobId, + 0L, // unused + Nil, // unused + properties) + } + override def batchDuration: Duration = Milliseconds(100) test("onBatchSubmitted, onBatchStarted, onBatchCompleted, " + @@ -43,7 +57,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { // onBatchSubmitted val batchInfoSubmitted = BatchInfo(Time(1000), receivedBlockInfo, 1000, None, None) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) - listener.waitingBatches should be (List(batchInfoSubmitted)) + listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted))) listener.runningBatches should be (Nil) listener.retainedCompletedBatches should be (Nil) listener.lastCompletedBatch should be (None) @@ -56,7 +70,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val batchInfoStarted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) listener.waitingBatches should be (Nil) - listener.runningBatches should be (List(batchInfoStarted)) + listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) listener.retainedCompletedBatches should be (Nil) listener.lastCompletedBatch should be (None) listener.numUnprocessedBatches should be (1) @@ -64,13 +78,40 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalProcessedRecords should be (0) listener.numTotalReceivedRecords should be (600) + // onJobStart + val jobStart1 = createJobStart(Time(1000), outputOpId = 0, jobId = 0) + listener.onJobStart(jobStart1) + + val jobStart2 = createJobStart(Time(1000), outputOpId = 0, jobId = 1) + listener.onJobStart(jobStart2) + + val jobStart3 = createJobStart(Time(1000), outputOpId = 1, jobId = 0) + listener.onJobStart(jobStart3) + + val jobStart4 = createJobStart(Time(1000), outputOpId = 1, jobId = 1) + listener.onJobStart(jobStart4) + + val batchUIData = listener.getBatchUIData(Time(1000)) + batchUIData should not be None + batchUIData.get.batchTime should be (batchInfoStarted.batchTime) + batchUIData.get.schedulingDelay should be (batchInfoStarted.schedulingDelay) + batchUIData.get.processingDelay should be (batchInfoStarted.processingDelay) + batchUIData.get.totalDelay should be (batchInfoStarted.totalDelay) + batchUIData.get.receiverNumRecords should be (Map(0 -> 300L, 1 -> 300L)) + batchUIData.get.numRecords should be(600) + batchUIData.get.outputOpIdSparkJobIdPairs should be + Seq(OutputOpIdAndSparkJobId(0, 0), + OutputOpIdAndSparkJobId(0, 1), + OutputOpIdAndSparkJobId(1, 0), + OutputOpIdAndSparkJobId(1, 1)) + // onBatchCompleted val batchInfoCompleted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (Nil) - listener.retainedCompletedBatches should be (List(batchInfoCompleted)) - listener.lastCompletedBatch should be (Some(batchInfoCompleted)) + listener.retainedCompletedBatches should be (List(BatchUIData(batchInfoCompleted))) + listener.lastCompletedBatch should be (Some(BatchUIData(batchInfoCompleted))) listener.numUnprocessedBatches should be (0) listener.numTotalCompletedBatches should be (1) listener.numTotalProcessedRecords should be (600) @@ -116,4 +157,55 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.retainedCompletedBatches.size should be (limit) listener.numTotalCompletedBatches should be(limit + 10) } + + test("out-of-order onJobStart and onBatchXXX") { + val ssc = setupStreams(input, operation) + val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 100) + val listener = new StreamingJobProgressListener(ssc) + + // fulfill completedBatchInfos + for(i <- 0 until limit) { + val batchInfoCompleted = + BatchInfo(Time(1000 + i * 100), Map.empty, 1000 + i * 100, Some(2000 + i * 100), None) + listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) + val jobStart = createJobStart(Time(1000 + i * 100), outputOpId = 0, jobId = 1) + listener.onJobStart(jobStart) + } + + // onJobStart happens before onBatchSubmitted + val jobStart = createJobStart(Time(1000 + limit * 100), outputOpId = 0, jobId = 0) + listener.onJobStart(jobStart) + + val batchInfoSubmitted = + BatchInfo(Time(1000 + limit * 100), Map.empty, (1000 + limit * 100), None, None) + listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) + + // We still can see the info retrieved from onJobStart + val batchUIData = listener.getBatchUIData(Time(1000 + limit * 100)) + batchUIData should not be None + batchUIData.get.batchTime should be (batchInfoSubmitted.batchTime) + batchUIData.get.schedulingDelay should be (batchInfoSubmitted.schedulingDelay) + batchUIData.get.processingDelay should be (batchInfoSubmitted.processingDelay) + batchUIData.get.totalDelay should be (batchInfoSubmitted.totalDelay) + batchUIData.get.receiverNumRecords should be (Map.empty) + batchUIData.get.numRecords should be (0) + batchUIData.get.outputOpIdSparkJobIdPairs should be (Seq(OutputOpIdAndSparkJobId(0, 0))) + + // A lot of "onBatchCompleted"s happen before "onJobStart" + for(i <- limit + 1 to limit * 2) { + val batchInfoCompleted = + BatchInfo(Time(1000 + i * 100), Map.empty, 1000 + i * 100, Some(2000 + i * 100), None) + listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) + } + + for(i <- limit + 1 to limit * 2) { + val jobStart = createJobStart(Time(1000 + i * 100), outputOpId = 0, jobId = 1) + listener.onJobStart(jobStart) + } + + // We should not leak memory + listener.batchTimeToOutputOpIdSparkJobIdPair.size() should be <= + (listener.waitingBatches.size + listener.runningBatches.size + + listener.retainedCompletedBatches.size + 10) + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index a3919c43b95b4..79098bcf4861c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -18,33 +18,38 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer +import java.util import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} +import scala.reflect.ClassTag -import WriteAheadLogSuite._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.util.{ManualClock, Utils} -import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Eventually._ +import org.scalatest.{BeforeAndAfter, FunSuite} + +import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.{SparkConf, SparkException} class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { + import WriteAheadLogSuite._ + val hadoopConf = new Configuration() var tempDir: File = null var testDir: String = null var testFile: String = null - var manager: WriteAheadLogManager = null + var writeAheadLog: FileBasedWriteAheadLog = null before { tempDir = Utils.createTempDir() testDir = tempDir.toString testFile = new File(tempDir, "testFile").toString - if (manager != null) { - manager.stop() - manager = null + if (writeAheadLog != null) { + writeAheadLog.close() + writeAheadLog = null } } @@ -52,16 +57,60 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { Utils.deleteRecursively(tempDir) } - test("WriteAheadLogWriter - writing data") { + test("WriteAheadLogUtils - log selection and creation") { + val logDir = Utils.createTempDir().getAbsolutePath() + + def assertDriverLogClass[T <: WriteAheadLog: ClassTag](conf: SparkConf): WriteAheadLog = { + val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf) + assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) + log + } + + def assertReceiverLogClass[T: ClassTag](conf: SparkConf): WriteAheadLog = { + val log = WriteAheadLogUtils.createLogForReceiver(conf, logDir, hadoopConf) + assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) + log + } + + val emptyConf = new SparkConf() // no log configuration + assertDriverLogClass[FileBasedWriteAheadLog](emptyConf) + assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf) + + // Verify setting driver WAL class + val conf1 = new SparkConf().set("spark.streaming.driver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[MockWriteAheadLog0](conf1) + assertReceiverLogClass[FileBasedWriteAheadLog](conf1) + + // Verify setting receiver WAL class + val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf) + assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) + + // Verify setting receiver WAL class with 1-arg constructor + val receiverWALConf2 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog1].getName()) + assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf2) + + // Verify failure setting receiver WAL class with 2-arg constructor + intercept[SparkException] { + val receiverWALConf3 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog2].getName()) + assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf3) + } + } + + test("FileBasedWriteAheadLogWriter - writing data") { val dataToWrite = generateRandomData() val segments = writeDataUsingWriter(testFile, dataToWrite) val writtenData = readDataManually(segments) assert(writtenData === dataToWrite) } - test("WriteAheadLogWriter - syncing of data by writing and reading immediately") { + test("FileBasedWriteAheadLogWriter - syncing of data by writing and reading immediately") { val dataToWrite = generateRandomData() - val writer = new WriteAheadLogWriter(testFile, hadoopConf) + val writer = new FileBasedWriteAheadLogWriter(testFile, hadoopConf) dataToWrite.foreach { data => val segment = writer.write(stringToByteBuffer(data)) val dataRead = readDataManually(Seq(segment)).head @@ -70,10 +119,10 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { writer.close() } - test("WriteAheadLogReader - sequentially reading data") { + test("FileBasedWriteAheadLogReader - sequentially reading data") { val writtenData = generateRandomData() writeDataManually(writtenData, testFile) - val reader = new WriteAheadLogReader(testFile, hadoopConf) + val reader = new FileBasedWriteAheadLogReader(testFile, hadoopConf) val readData = reader.toSeq.map(byteBufferToString) assert(readData === writtenData) assert(reader.hasNext === false) @@ -83,14 +132,14 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { reader.close() } - test("WriteAheadLogReader - sequentially reading data written with writer") { + test("FileBasedWriteAheadLogReader - sequentially reading data written with writer") { val dataToWrite = generateRandomData() writeDataUsingWriter(testFile, dataToWrite) val readData = readDataUsingReader(testFile) assert(readData === dataToWrite) } - test("WriteAheadLogReader - reading data written with writer after corrupted write") { + test("FileBasedWriteAheadLogReader - reading data written with writer after corrupted write") { // Write data manually for testing the sequential reader val dataToWrite = generateRandomData() writeDataUsingWriter(testFile, dataToWrite) @@ -113,38 +162,38 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { assert(readDataUsingReader(testFile) === (dataToWrite.dropRight(1))) } - test("WriteAheadLogRandomReader - reading data using random reader") { + test("FileBasedWriteAheadLogRandomReader - reading data using random reader") { // Write data manually for testing the random reader val writtenData = generateRandomData() val segments = writeDataManually(writtenData, testFile) // Get a random order of these segments and read them back val writtenDataAndSegments = writtenData.zip(segments).toSeq.permutations.take(10).flatten - val reader = new WriteAheadLogRandomReader(testFile, hadoopConf) + val reader = new FileBasedWriteAheadLogRandomReader(testFile, hadoopConf) writtenDataAndSegments.foreach { case (data, segment) => assert(data === byteBufferToString(reader.read(segment))) } reader.close() } - test("WriteAheadLogRandomReader - reading data using random reader written with writer") { + test("FileBasedWriteAheadLogRandomReader- reading data using random reader written with writer") { // Write data using writer for testing the random reader val data = generateRandomData() val segments = writeDataUsingWriter(testFile, data) // Read a random sequence of segments and verify read data val dataAndSegments = data.zip(segments).toSeq.permutations.take(10).flatten - val reader = new WriteAheadLogRandomReader(testFile, hadoopConf) + val reader = new FileBasedWriteAheadLogRandomReader(testFile, hadoopConf) dataAndSegments.foreach { case (data, segment) => assert(data === byteBufferToString(reader.read(segment))) } reader.close() } - test("WriteAheadLogManager - write rotating logs") { - // Write data using manager + test("FileBasedWriteAheadLog - write rotating logs") { + // Write data with rotation using WriteAheadLog class val dataToWrite = generateRandomData() - writeDataUsingManager(testDir, dataToWrite) + writeDataUsingWriteAheadLog(testDir, dataToWrite) // Read data manually to verify the written data val logFiles = getLogFilesInDirectory(testDir) @@ -153,8 +202,8 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { assert(writtenData === dataToWrite) } - test("WriteAheadLogManager - read rotating logs") { - // Write data manually for testing reading through manager + test("FileBasedWriteAheadLog - read rotating logs") { + // Write data manually for testing reading through WriteAheadLog val writtenData = (1 to 10).map { i => val data = generateRandomData() val file = testDir + s"/log-$i-$i" @@ -167,25 +216,25 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { assert(fileSystem.exists(logDirectoryPath) === true) // Read data using manager and verify - val readData = readDataUsingManager(testDir) + val readData = readDataUsingWriteAheadLog(testDir) assert(readData === writtenData) } - test("WriteAheadLogManager - recover past logs when creating new manager") { + test("FileBasedWriteAheadLog - recover past logs when creating new manager") { // Write data with manager, recover with new manager and verify val dataToWrite = generateRandomData() - writeDataUsingManager(testDir, dataToWrite) + writeDataUsingWriteAheadLog(testDir, dataToWrite) val logFiles = getLogFilesInDirectory(testDir) assert(logFiles.size > 1) - val readData = readDataUsingManager(testDir) + val readData = readDataUsingWriteAheadLog(testDir) assert(dataToWrite === readData) } - test("WriteAheadLogManager - cleanup old logs") { + test("FileBasedWriteAheadLog - clean old logs") { logCleanUpTest(waitForCompletion = false) } - test("WriteAheadLogManager - cleanup old logs synchronously") { + test("FileBasedWriteAheadLog - clean old logs synchronously") { logCleanUpTest(waitForCompletion = true) } @@ -193,11 +242,11 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { // Write data with manager, recover with new manager and verify val manualClock = new ManualClock val dataToWrite = generateRandomData() - manager = writeDataUsingManager(testDir, dataToWrite, manualClock, stopManager = false) + writeAheadLog = writeDataUsingWriteAheadLog(testDir, dataToWrite, manualClock, closeLog = false) val logFiles = getLogFilesInDirectory(testDir) assert(logFiles.size > 1) - manager.cleanupOldLogs(manualClock.getTimeMillis() / 2, waitForCompletion) + writeAheadLog.clean(manualClock.getTimeMillis() / 2, waitForCompletion) if (waitForCompletion) { assert(getLogFilesInDirectory(testDir).size < logFiles.size) @@ -208,11 +257,11 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { } } - test("WriteAheadLogManager - handling file errors while reading rotating logs") { + test("FileBasedWriteAheadLog - handling file errors while reading rotating logs") { // Generate a set of log files val manualClock = new ManualClock val dataToWrite1 = generateRandomData() - writeDataUsingManager(testDir, dataToWrite1, manualClock) + writeDataUsingWriteAheadLog(testDir, dataToWrite1, manualClock) val logFiles1 = getLogFilesInDirectory(testDir) assert(logFiles1.size > 1) @@ -220,12 +269,12 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { // Recover old files and generate a second set of log files val dataToWrite2 = generateRandomData() manualClock.advance(100000) - writeDataUsingManager(testDir, dataToWrite2, manualClock) + writeDataUsingWriteAheadLog(testDir, dataToWrite2, manualClock) val logFiles2 = getLogFilesInDirectory(testDir) assert(logFiles2.size > logFiles1.size) // Read the files and verify that all the written data can be read - val readData1 = readDataUsingManager(testDir) + val readData1 = readDataUsingWriteAheadLog(testDir) assert(readData1 === (dataToWrite1 ++ dataToWrite2)) // Corrupt the first set of files so that they are basically unreadable @@ -236,25 +285,51 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { } // Verify that the corrupted files do not prevent reading of the second set of data - val readData = readDataUsingManager(testDir) + val readData = readDataUsingWriteAheadLog(testDir) assert(readData === dataToWrite2) } + + test("FileBasedWriteAheadLog - do not create directories or files unless write") { + val nonexistentTempPath = File.createTempFile("test", "") + nonexistentTempPath.delete() + assert(!nonexistentTempPath.exists()) + + val writtenSegment = writeDataManually(generateRandomData(), testFile) + val wal = new FileBasedWriteAheadLog( + new SparkConf(), tempDir.getAbsolutePath, new Configuration(), 1, 1) + assert(!nonexistentTempPath.exists(), "Directory created just by creating log object") + wal.read(writtenSegment.head) + assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment") + } } object WriteAheadLogSuite { + class MockWriteAheadLog0() extends WriteAheadLog { + override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { null } + override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = { null } + override def readAll(): util.Iterator[ByteBuffer] = { null } + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { } + override def close(): Unit = { } + } + + class MockWriteAheadLog1(val conf: SparkConf) extends MockWriteAheadLog0() + + class MockWriteAheadLog2(val conf: SparkConf, x: Int) extends MockWriteAheadLog0() + + private val hadoopConf = new Configuration() /** Write data to a file directly and return an array of the file segments written. */ - def writeDataManually(data: Seq[String], file: String): Seq[WriteAheadLogFileSegment] = { - val segments = new ArrayBuffer[WriteAheadLogFileSegment]() + def writeDataManually(data: Seq[String], file: String): Seq[FileBasedWriteAheadLogSegment] = { + val segments = new ArrayBuffer[FileBasedWriteAheadLogSegment]() val writer = HdfsUtils.getOutputStream(file, hadoopConf) data.foreach { item => val offset = writer.getPos val bytes = Utils.serialize(item) writer.writeInt(bytes.size) writer.write(bytes) - segments += WriteAheadLogFileSegment(file, offset, bytes.size) + segments += FileBasedWriteAheadLogSegment(file, offset, bytes.size) } writer.close() segments @@ -263,8 +338,11 @@ object WriteAheadLogSuite { /** * Write data to a file using the writer class and return an array of the file segments written. */ - def writeDataUsingWriter(filePath: String, data: Seq[String]): Seq[WriteAheadLogFileSegment] = { - val writer = new WriteAheadLogWriter(filePath, hadoopConf) + def writeDataUsingWriter( + filePath: String, + data: Seq[String] + ): Seq[FileBasedWriteAheadLogSegment] = { + val writer = new FileBasedWriteAheadLogWriter(filePath, hadoopConf) val segments = data.map { item => writer.write(item) } @@ -272,27 +350,27 @@ object WriteAheadLogSuite { segments } - /** Write data to rotating files in log directory using the manager class. */ - def writeDataUsingManager( + /** Write data to rotating files in log directory using the WriteAheadLog class. */ + def writeDataUsingWriteAheadLog( logDirectory: String, data: Seq[String], manualClock: ManualClock = new ManualClock, - stopManager: Boolean = true - ): WriteAheadLogManager = { + closeLog: Boolean = true + ): FileBasedWriteAheadLog = { if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) - val manager = new WriteAheadLogManager(logDirectory, hadoopConf, - rollingIntervalSecs = 1, callerName = "WriteAheadLogSuite", clock = manualClock) + val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) + // Ensure that 500 does not get sorted after 2000, so put a high base value. data.foreach { item => manualClock.advance(500) - manager.writeToLog(item) + wal.write(item, manualClock.getTimeMillis()) } - if (stopManager) manager.stop() - manager + if (closeLog) wal.close() + wal } /** Read data from a segments of a log file directly and return the list of byte buffers. */ - def readDataManually(segments: Seq[WriteAheadLogFileSegment]): Seq[String] = { + def readDataManually(segments: Seq[FileBasedWriteAheadLogSegment]): Seq[String] = { segments.map { segment => val reader = HdfsUtils.getInputStream(segment.path, hadoopConf) try { @@ -331,18 +409,18 @@ object WriteAheadLogSuite { /** Read all the data from a log file using reader class and return the list of byte buffers. */ def readDataUsingReader(file: String): Seq[String] = { - val reader = new WriteAheadLogReader(file, hadoopConf) + val reader = new FileBasedWriteAheadLogReader(file, hadoopConf) val readData = reader.toList.map(byteBufferToString) reader.close() readData } - /** Read all the data in the log file in a directory using the manager class. */ - def readDataUsingManager(logDirectory: String): Seq[String] = { - val manager = new WriteAheadLogManager(logDirectory, hadoopConf, - callerName = "WriteAheadLogSuite") - val data = manager.readFromLog().map(byteBufferToString).toSeq - manager.stop() + /** Read all the data in the log file in a directory using the WriteAheadLog class. */ + def readDataUsingWriteAheadLog(logDirectory: String): Seq[String] = { + import scala.collection.JavaConversions._ + val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) + val data = wal.readAll().map(byteBufferToString).toSeq + wal.close() data } diff --git a/unsafe/pom.xml b/unsafe/pom.xml new file mode 100644 index 0000000000000..5b0733206b2bc --- /dev/null +++ b/unsafe/pom.xml @@ -0,0 +1,93 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.4.0-SNAPSHOT + ../pom.xml + + + org.apache.spark + spark-unsafe_2.10 + jar + Spark Project Unsafe + http://spark.apache.org/ + + unsafe + + + + + + + com.google.code.findbugs + jsr305 + + + + + org.slf4j + slf4j-api + provided + + + + + junit + junit + test + + + com.novocode + junit-interface + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + net.alchim31.maven + scala-maven-plugin + + + + -XDignore.symbol.file + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + + -XDignore.symbol.file + + + + + + + diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java new file mode 100644 index 0000000000000..24b2892098059 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe; + +import java.lang.reflect.Field; + +import sun.misc.Unsafe; + +public final class PlatformDependent { + + /** + * Facade in front of {@link sun.misc.Unsafe}, used to avoid directly exposing Unsafe outside of + * this package. This also lets us aovid accidental use of deprecated methods or methods that + * aren't present in Java 6. + */ + public static final class UNSAFE { + + private UNSAFE() { } + + public static int getInt(Object object, long offset) { + return _UNSAFE.getInt(object, offset); + } + + public static void putInt(Object object, long offset, int value) { + _UNSAFE.putInt(object, offset, value); + } + + public static boolean getBoolean(Object object, long offset) { + return _UNSAFE.getBoolean(object, offset); + } + + public static void putBoolean(Object object, long offset, boolean value) { + _UNSAFE.putBoolean(object, offset, value); + } + + public static byte getByte(Object object, long offset) { + return _UNSAFE.getByte(object, offset); + } + + public static void putByte(Object object, long offset, byte value) { + _UNSAFE.putByte(object, offset, value); + } + + public static short getShort(Object object, long offset) { + return _UNSAFE.getShort(object, offset); + } + + public static void putShort(Object object, long offset, short value) { + _UNSAFE.putShort(object, offset, value); + } + + public static long getLong(Object object, long offset) { + return _UNSAFE.getLong(object, offset); + } + + public static void putLong(Object object, long offset, long value) { + _UNSAFE.putLong(object, offset, value); + } + + public static float getFloat(Object object, long offset) { + return _UNSAFE.getFloat(object, offset); + } + + public static void putFloat(Object object, long offset, float value) { + _UNSAFE.putFloat(object, offset, value); + } + + public static double getDouble(Object object, long offset) { + return _UNSAFE.getDouble(object, offset); + } + + public static void putDouble(Object object, long offset, double value) { + _UNSAFE.putDouble(object, offset, value); + } + + public static long allocateMemory(long size) { + return _UNSAFE.allocateMemory(size); + } + + public static void freeMemory(long address) { + _UNSAFE.freeMemory(address); + } + + } + + private static final Unsafe _UNSAFE; + + public static final int BYTE_ARRAY_OFFSET; + + public static final int INT_ARRAY_OFFSET; + + public static final int LONG_ARRAY_OFFSET; + + public static final int DOUBLE_ARRAY_OFFSET; + + /** + * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to + * allow safepoint polling during a large copy. + */ + private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; + + static { + sun.misc.Unsafe unsafe; + try { + Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); + unsafeField.setAccessible(true); + unsafe = (sun.misc.Unsafe) unsafeField.get(null); + } catch (Throwable cause) { + unsafe = null; + } + _UNSAFE = unsafe; + + if (_UNSAFE != null) { + BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); + INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); + LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); + DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); + } else { + BYTE_ARRAY_OFFSET = 0; + INT_ARRAY_OFFSET = 0; + LONG_ARRAY_OFFSET = 0; + DOUBLE_ARRAY_OFFSET = 0; + } + } + + static public void copyMemory( + Object src, + long srcOffset, + Object dst, + long dstOffset, + long length) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } + + /** + * Raises an exception bypassing compiler checks for checked exceptions. + */ + public static void throwException(Throwable t) { + _UNSAFE.throwException(t); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java new file mode 100644 index 0000000000000..53eadf96a6b52 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.array; + +import org.apache.spark.unsafe.PlatformDependent; + +public class ByteArrayMethods { + + private ByteArrayMethods() { + // Private constructor, since this class only contains static methods. + } + + public static int roundNumberOfBytesToNearestWord(int numBytes) { + int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` + if (remainder == 0) { + return numBytes; + } else { + return numBytes + (8 - remainder); + } + } + + /** + * Optimized byte array equality check for 8-byte-word-aligned byte arrays. + * @return true if the arrays are equal, false otherwise + */ + public static boolean wordAlignedArrayEquals( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset, + long arrayLengthInBytes) { + for (int i = 0; i < arrayLengthInBytes; i += 8) { + final long left = + PlatformDependent.UNSAFE.getLong(leftBaseObject, leftBaseOffset + i); + final long right = + PlatformDependent.UNSAFE.getLong(rightBaseObject, rightBaseOffset + i); + if (left != right) return false; + } + return true; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java new file mode 100644 index 0000000000000..18d1f0d2d7eb2 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.array; + +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; + +/** + * An array of long values. Compared with native JVM arrays, this: + *
    + *
  • supports using both in-heap and off-heap memory
  • + *
  • has no bound checking, and thus can crash the JVM process when assert is turned off
  • + *
+ */ +public final class LongArray { + + // This is a long so that we perform long multiplications when computing offsets. + private static final long WIDTH = 8; + + private final MemoryBlock memory; + private final Object baseObj; + private final long baseOffset; + + private final long length; + + public LongArray(MemoryBlock memory) { + assert memory.size() % WIDTH == 0 : "Memory not aligned (" + memory.size() + ")"; + assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size > 4 billion elements"; + this.memory = memory; + this.baseObj = memory.getBaseObject(); + this.baseOffset = memory.getBaseOffset(); + this.length = memory.size() / WIDTH; + } + + public MemoryBlock memoryBlock() { + return memory; + } + + /** + * Returns the number of elements this array can hold. + */ + public long size() { + return length; + } + + /** + * Sets the value at position {@code index}. + */ + public void set(int index, long value) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < length : "index (" + index + ") should < length (" + length + ")"; + PlatformDependent.UNSAFE.putLong(baseObj, baseOffset + index * WIDTH, value); + } + + /** + * Returns the value at position {@code index}. + */ + public long get(int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < length : "index (" + index + ") should < length (" + length + ")"; + return PlatformDependent.UNSAFE.getLong(baseObj, baseOffset + index * WIDTH); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java new file mode 100644 index 0000000000000..f72e07fce92fd --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.bitset; + +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; + +/** + * A fixed size uncompressed bit set backed by a {@link LongArray}. + * + * Each bit occupies exactly one bit of storage. + */ +public final class BitSet { + + /** A long array for the bits. */ + private final LongArray words; + + /** Length of the long array. */ + private final int numWords; + + private final Object baseObject; + private final long baseOffset; + + /** + * Creates a new {@link BitSet} using the specified memory block. Size of the memory block must be + * multiple of 8 bytes (i.e. 64 bits). + */ + public BitSet(MemoryBlock memory) { + words = new LongArray(memory); + assert (words.size() <= Integer.MAX_VALUE); + numWords = (int) words.size(); + baseObject = words.memoryBlock().getBaseObject(); + baseOffset = words.memoryBlock().getBaseOffset(); + } + + public MemoryBlock memoryBlock() { + return words.memoryBlock(); + } + + /** + * Returns the number of bits in this {@code BitSet}. + */ + public long capacity() { + return numWords * 64; + } + + /** + * Sets the bit at the specified index to {@code true}. + */ + public void set(int index) { + assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; + BitSetMethods.set(baseObject, baseOffset, index); + } + + /** + * Sets the bit at the specified index to {@code false}. + */ + public void unset(int index) { + assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; + BitSetMethods.unset(baseObject, baseOffset, index); + } + + /** + * Returns {@code true} if the bit is set at the specified index. + */ + public boolean isSet(int index) { + assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; + return BitSetMethods.isSet(baseObject, baseOffset, index); + } + + /** + * Returns the index of the first bit that is set to true that occurs on or after the + * specified starting index. If no such bit exists then {@code -1} is returned. + *

+ * To iterate over the true bits in a BitSet, use the following loop: + *

+   * 
+   *  for (long i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
+   *    // operate on index i here
+   *  }
+   * 
+   * 
+ * + * @param fromIndex the index to start checking from (inclusive) + * @return the index of the next set bit, or -1 if there is no such bit + */ + public int nextSetBit(int fromIndex) { + return BitSetMethods.nextSetBit(baseObject, baseOffset, fromIndex, numWords); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java new file mode 100644 index 0000000000000..f30626d8f4317 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.bitset; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * Methods for working with fixed-size uncompressed bitsets. + * + * We assume that the bitset data is word-aligned (that is, a multiple of 8 bytes in length). + * + * Each bit occupies exactly one bit of storage. + */ +public final class BitSetMethods { + + private static final long WORD_SIZE = 8; + + private BitSetMethods() { + // Make the default constructor private, since this only holds static methods. + } + + /** + * Sets the bit at the specified index to {@code true}. + */ + public static void set(Object baseObject, long baseOffset, int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + final long mask = 1L << (index & 0x3f); // mod 64 and shift + final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; + final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word | mask); + } + + /** + * Sets the bit at the specified index to {@code false}. + */ + public static void unset(Object baseObject, long baseOffset, int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + final long mask = 1L << (index & 0x3f); // mod 64 and shift + final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; + final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word & ~mask); + } + + /** + * Returns {@code true} if the bit is set at the specified index. + */ + public static boolean isSet(Object baseObject, long baseOffset, int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + final long mask = 1L << (index & 0x3f); // mod 64 and shift + final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; + final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + return (word & mask) != 0; + } + + /** + * Returns {@code true} if any bit is set. + */ + public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInBytes) { + for (int i = 0; i <= bitSetWidthInBytes; i++) { + if (PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + i) != 0) { + return true; + } + } + return false; + } + + /** + * Returns the index of the first bit that is set to true that occurs on or after the + * specified starting index. If no such bit exists then {@code -1} is returned. + *

+ * To iterate over the true bits in a BitSet, use the following loop: + *

+   * 
+   *  for (long i = bs.nextSetBit(0, sizeInWords); i >= 0; i = bs.nextSetBit(i + 1, sizeInWords)) {
+   *    // operate on index i here
+   *  }
+   * 
+   * 
+ * + * @param fromIndex the index to start checking from (inclusive) + * @param bitsetSizeInWords the size of the bitset, measured in 8-byte words + * @return the index of the next set bit, or -1 if there is no such bit + */ + public static int nextSetBit( + Object baseObject, + long baseOffset, + int fromIndex, + int bitsetSizeInWords) { + int wi = fromIndex >> 6; + if (wi >= bitsetSizeInWords) { + return -1; + } + + // Try to find the next set bit in the current word + final int subIndex = fromIndex & 0x3f; + long word = + PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex; + if (word != 0) { + return (wi << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word); + } + + // Find the next set bit in the rest of the words + wi += 1; + while (wi < bitsetSizeInWords) { + word = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE); + if (word != 0) { + return (wi << 6) + java.lang.Long.numberOfTrailingZeros(word); + } + wi += 1; + } + + return -1; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java new file mode 100644 index 0000000000000..85cd02469adb7 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.hash; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. + */ +public final class Murmur3_x86_32 { + private static final int C1 = 0xcc9e2d51; + private static final int C2 = 0x1b873593; + + private final int seed; + + public Murmur3_x86_32(int seed) { + this.seed = seed; + } + + @Override + public String toString() { + return "Murmur3_32(seed=" + seed + ")"; + } + + public int hashInt(int input) { + int k1 = mixK1(input); + int h1 = mixH1(seed, k1); + + return fmix(h1, 4); + } + + public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) { + // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; + int h1 = seed; + for (int offset = 0; offset < lengthInBytes; offset += 4) { + int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset); + int k1 = mixK1(halfWord); + h1 = mixH1(h1, k1); + } + return fmix(h1, lengthInBytes); + } + + public int hashLong(long input) { + int low = (int) input; + int high = (int) (input >>> 32); + + int k1 = mixK1(low); + int h1 = mixH1(seed, k1); + + k1 = mixK1(high); + h1 = mixH1(h1, k1); + + return fmix(h1, 8); + } + + private static int mixK1(int k1) { + k1 *= C1; + k1 = Integer.rotateLeft(k1, 15); + k1 *= C2; + return k1; + } + + private static int mixH1(int h1, int k1) { + h1 ^= k1; + h1 = Integer.rotateLeft(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + return h1; + } + + // Finalization mix - force all bits of a hash block to avalanche + private static int fmix(int h1, int length) { + h1 ^= length; + h1 ^= h1 >>> 16; + h1 *= 0x85ebca6b; + h1 ^= h1 >>> 13; + h1 *= 0xc2b2ae35; + h1 ^= h1 >>> 16; + return h1; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java new file mode 100644 index 0000000000000..4e5ebc402be35 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -0,0 +1,548 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +import java.lang.Override; +import java.lang.UnsupportedOperationException; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import org.apache.spark.unsafe.*; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.bitset.BitSet; +import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.*; + +/** + * An append-only hash map where keys and values are contiguous regions of bytes. + *

+ * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers, + * which is guaranteed to exhaust the space. + *

+ * The map can support up to 2^31 keys because we use 32 bit MurmurHash. If the key cardinality is + * higher than this, you should probably be using sorting instead of hashing for better cache + * locality. + *

+ * This class is not thread safe. + */ +public final class BytesToBytesMap { + + private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0); + + private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; + + private final TaskMemoryManager memoryManager; + + /** + * A linked list for tracking all allocated data pages so that we can free all of our memory. + */ + private final List dataPages = new LinkedList(); + + /** + * The data page that will be used to store keys and values for new hashtable entries. When this + * page becomes full, a new page will be allocated and this pointer will change to point to that + * new page. + */ + private MemoryBlock currentDataPage = null; + + /** + * Offset into `currentDataPage` that points to the location where new data can be inserted into + * the page. + */ + private long pageCursor = 0; + + /** + * The size of the data pages that hold key and value data. Map entries cannot span multiple + * pages, so this limits the maximum entry size. + */ + private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes + + // This choice of page table size and page size means that we can address up to 500 gigabytes + // of memory. + + /** + * A single array to store the key and value. + * + * Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i}, + * while position {@code 2 * i + 1} in the array holds key's full 32-bit hashcode. + */ + private LongArray longArray; + // TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode + // and exploit word-alignment to use fewer bits to hold the address. This might let us store + // only one long per map entry, increasing the chance that this array will fit in cache at the + // expense of maybe performing more lookups if we have hash collisions. Say that we stored only + // 27 bits of the hashcode and 37 bits of the address. 37 bits is enough to address 1 terabyte + // of RAM given word-alignment. If we use 13 bits of this for our page table, that gives us a + // maximum page size of 2^24 * 8 = ~134 megabytes per page. This change will require us to store + // full base addresses in the page table for off-heap mode so that we can reconstruct the full + // absolute memory addresses. + + /** + * A {@link BitSet} used to track location of the map where the key is set. + * Size of the bitset should be half of the size of the long array. + */ + private BitSet bitset; + + private final double loadFactor; + + /** + * Number of keys defined in the map. + */ + private int size; + + /** + * The map will be expanded once the number of keys exceeds this threshold. + */ + private int growthThreshold; + + /** + * Mask for truncating hashcodes so that they do not exceed the long array's size. + * This is a strength reduction optimization; we're essentially performing a modulus operation, + * but doing so with a bitmask because this is a power-of-2-sized hash map. + */ + private int mask; + + /** + * Return value of {@link BytesToBytesMap#lookup(Object, long, int)}. + */ + private final Location loc; + + private final boolean enablePerfMetrics; + + private long timeSpentResizingNs = 0; + + private long numProbes = 0; + + private long numKeyLookups = 0; + + private long numHashCollisions = 0; + + public BytesToBytesMap( + TaskMemoryManager memoryManager, + int initialCapacity, + double loadFactor, + boolean enablePerfMetrics) { + this.memoryManager = memoryManager; + this.loadFactor = loadFactor; + this.loc = new Location(); + this.enablePerfMetrics = enablePerfMetrics; + allocate(initialCapacity); + } + + public BytesToBytesMap(TaskMemoryManager memoryManager, int initialCapacity) { + this(memoryManager, initialCapacity, 0.70, false); + } + + public BytesToBytesMap( + TaskMemoryManager memoryManager, + int initialCapacity, + boolean enablePerfMetrics) { + this(memoryManager, initialCapacity, 0.70, enablePerfMetrics); + } + + /** + * Returns the number of keys defined in the map. + */ + public int size() { return size; } + + /** + * Returns an iterator for iterating over the entries of this map. + * + * For efficiency, all calls to `next()` will return the same {@link Location} object. + * + * If any other lookups or operations are performed on this map while iterating over it, including + * `lookup()`, the behavior of the returned iterator is undefined. + */ + public Iterator iterator() { + return new Iterator() { + + private int nextPos = bitset.nextSetBit(0); + + @Override + public boolean hasNext() { + return nextPos != -1; + } + + @Override + public Location next() { + final int pos = nextPos; + nextPos = bitset.nextSetBit(nextPos + 1); + return loc.with(pos, 0, true); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + /** + * Looks up a key, and return a {@link Location} handle that can be used to test existence + * and read/write values. + * + * This function always return the same {@link Location} instance to avoid object allocation. + */ + public Location lookup( + Object keyBaseObject, + long keyBaseOffset, + int keyRowLengthBytes) { + if (enablePerfMetrics) { + numKeyLookups++; + } + final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes); + int pos = hashcode & mask; + int step = 1; + while (true) { + if (enablePerfMetrics) { + numProbes++; + } + if (!bitset.isSet(pos)) { + // This is a new key. + return loc.with(pos, hashcode, false); + } else { + long stored = longArray.get(pos * 2 + 1); + if ((int) (stored) == hashcode) { + // Full hash code matches. Let's compare the keys for equality. + loc.with(pos, hashcode, true); + if (loc.getKeyLength() == keyRowLengthBytes) { + final MemoryLocation keyAddress = loc.getKeyAddress(); + final Object storedKeyBaseObject = keyAddress.getBaseObject(); + final long storedKeyBaseOffset = keyAddress.getBaseOffset(); + final boolean areEqual = ByteArrayMethods.wordAlignedArrayEquals( + keyBaseObject, + keyBaseOffset, + storedKeyBaseObject, + storedKeyBaseOffset, + keyRowLengthBytes + ); + if (areEqual) { + return loc; + } else { + if (enablePerfMetrics) { + numHashCollisions++; + } + } + } + } + } + pos = (pos + step) & mask; + step++; + } + } + + /** + * Handle returned by {@link BytesToBytesMap#lookup(Object, long, int)} function. + */ + public final class Location { + /** An index into the hash map's Long array */ + private int pos; + /** True if this location points to a position where a key is defined, false otherwise */ + private boolean isDefined; + /** + * The hashcode of the most recent key passed to + * {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to + * avoid re-hashing the key when storing a value for that key. + */ + private int keyHashcode; + private final MemoryLocation keyMemoryLocation = new MemoryLocation(); + private final MemoryLocation valueMemoryLocation = new MemoryLocation(); + private int keyLength; + private int valueLength; + + private void updateAddressesAndSizes(long fullKeyAddress) { + final Object page = memoryManager.getPage(fullKeyAddress); + final long keyOffsetInPage = memoryManager.getOffsetInPage(fullKeyAddress); + long position = keyOffsetInPage; + keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position); + position += 8; // word used to store the key size + keyMemoryLocation.setObjAndOffset(page, position); + position += keyLength; + valueLength = (int) PlatformDependent.UNSAFE.getLong(page, position); + position += 8; // word used to store the key size + valueMemoryLocation.setObjAndOffset(page, position); + } + + Location with(int pos, int keyHashcode, boolean isDefined) { + this.pos = pos; + this.isDefined = isDefined; + this.keyHashcode = keyHashcode; + if (isDefined) { + final long fullKeyAddress = longArray.get(pos * 2); + updateAddressesAndSizes(fullKeyAddress); + } + return this; + } + + /** + * Returns true if the key is defined at this position, and false otherwise. + */ + public boolean isDefined() { + return isDefined; + } + + /** + * Returns the address of the key defined at this position. + * This points to the first byte of the key data. + * Unspecified behavior if the key is not defined. + * For efficiency reasons, calls to this method always returns the same MemoryLocation object. + */ + public MemoryLocation getKeyAddress() { + assert (isDefined); + return keyMemoryLocation; + } + + /** + * Returns the length of the key defined at this position. + * Unspecified behavior if the key is not defined. + */ + public int getKeyLength() { + assert (isDefined); + return keyLength; + } + + /** + * Returns the address of the value defined at this position. + * This points to the first byte of the value data. + * Unspecified behavior if the key is not defined. + * For efficiency reasons, calls to this method always returns the same MemoryLocation object. + */ + public MemoryLocation getValueAddress() { + assert (isDefined); + return valueMemoryLocation; + } + + /** + * Returns the length of the value defined at this position. + * Unspecified behavior if the key is not defined. + */ + public int getValueLength() { + assert (isDefined); + return valueLength; + } + + /** + * Store a new key and value. This method may only be called once for a given key; if you want + * to update the value associated with a key, then you can directly manipulate the bytes stored + * at the value address. + *

+ * It is only valid to call this method immediately after calling `lookup()` using the same key. + *

+ * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length` + * will return information on the data stored by this `putNewKey` call. + *

+ * As an example usage, here's the proper way to store a new key: + *

+ *

+     *   Location loc = map.lookup(keyBaseOffset, keyBaseObject, keyLengthInBytes);
+     *   if (!loc.isDefined()) {
+     *     loc.putNewKey(keyBaseOffset, keyBaseObject, keyLengthInBytes, ...)
+     *   }
+     * 
+ *

+ * Unspecified behavior if the key is not defined. + */ + public void putNewKey( + Object keyBaseObject, + long keyBaseOffset, + int keyLengthBytes, + Object valueBaseObject, + long valueBaseOffset, + int valueLengthBytes) { + assert (!isDefined) : "Can only set value once for a key"; + isDefined = true; + assert (keyLengthBytes % 8 == 0); + assert (valueLengthBytes % 8 == 0); + // Here, we'll copy the data into our data pages. Because we only store a relative offset from + // the key address instead of storing the absolute address of the value, the key and value + // must be stored in the same memory page. + // (8 byte key length) (key) (8 byte value length) (value) + final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; + assert(requiredSize <= PAGE_SIZE_BYTES); + size++; + bitset.set(pos); + + // If there's not enough space in the current page, allocate a new page: + if (currentDataPage == null || PAGE_SIZE_BYTES - pageCursor < requiredSize) { + MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES); + dataPages.add(newPage); + pageCursor = 0; + currentDataPage = newPage; + } + + // Compute all of our offsets up-front: + final Object pageBaseObject = currentDataPage.getBaseObject(); + final long pageBaseOffset = currentDataPage.getBaseOffset(); + final long keySizeOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += 8; // word used to store the key size + final long keyDataOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += keyLengthBytes; + final long valueSizeOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += 8; // word used to store the value size + final long valueDataOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += valueLengthBytes; + + // Copy the key + PlatformDependent.UNSAFE.putLong(pageBaseObject, keySizeOffsetInPage, keyLengthBytes); + PlatformDependent.copyMemory( + keyBaseObject, keyBaseOffset, pageBaseObject, keyDataOffsetInPage, keyLengthBytes); + // Copy the value + PlatformDependent.UNSAFE.putLong(pageBaseObject, valueSizeOffsetInPage, valueLengthBytes); + PlatformDependent.copyMemory( + valueBaseObject, valueBaseOffset, pageBaseObject, valueDataOffsetInPage, valueLengthBytes); + + final long storedKeyAddress = memoryManager.encodePageNumberAndOffset( + currentDataPage, keySizeOffsetInPage); + longArray.set(pos * 2, storedKeyAddress); + longArray.set(pos * 2 + 1, keyHashcode); + updateAddressesAndSizes(storedKeyAddress); + isDefined = true; + if (size > growthThreshold) { + growAndRehash(); + } + } + } + + /** + * Allocate new data structures for this map. When calling this outside of the constructor, + * make sure to keep references to the old data structures so that you can free them. + * + * @param capacity the new map capacity + */ + private void allocate(int capacity) { + capacity = Math.max((int) Math.min(Integer.MAX_VALUE, nextPowerOf2(capacity)), 64); + longArray = new LongArray(memoryManager.allocate(capacity * 8 * 2)); + bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); + + this.growthThreshold = (int) (capacity * loadFactor); + this.mask = capacity - 1; + } + + /** + * Free all allocated memory associated with this map, including the storage for keys and values + * as well as the hash map array itself. + * + * This method is idempotent. + */ + public void free() { + if (longArray != null) { + memoryManager.free(longArray.memoryBlock()); + longArray = null; + } + if (bitset != null) { + // The bitset's heap memory isn't managed by a memory manager, so no need to free it here. + bitset = null; + } + Iterator dataPagesIterator = dataPages.iterator(); + while (dataPagesIterator.hasNext()) { + memoryManager.freePage(dataPagesIterator.next()); + dataPagesIterator.remove(); + } + assert(dataPages.isEmpty()); + } + + /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */ + public long getTotalMemoryConsumption() { + return ( + dataPages.size() * PAGE_SIZE_BYTES + + bitset.memoryBlock().size() + + longArray.memoryBlock().size()); + } + + /** + * Returns the total amount of time spent resizing this map (in nanoseconds). + */ + public long getTimeSpentResizingNs() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return timeSpentResizingNs; + } + + + /** + * Returns the average number of probes per key lookup. + */ + public double getAverageProbesPerLookup() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return (1.0 * numProbes) / numKeyLookups; + } + + public long getNumHashCollisions() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return numHashCollisions; + } + + /** + * Grows the size of the hash table and re-hash everything. + */ + private void growAndRehash() { + long resizeStartTime = -1; + if (enablePerfMetrics) { + resizeStartTime = System.nanoTime(); + } + // Store references to the old data structures to be used when we re-hash + final LongArray oldLongArray = longArray; + final BitSet oldBitSet = bitset; + final int oldCapacity = (int) oldBitSet.capacity(); + + // Allocate the new data structures + allocate(Math.min(Integer.MAX_VALUE, growthStrategy.nextCapacity(oldCapacity))); + + // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it) + for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) { + final long keyPointer = oldLongArray.get(pos * 2); + final int hashcode = (int) oldLongArray.get(pos * 2 + 1); + int newPos = hashcode & mask; + int step = 1; + boolean keepGoing = true; + + // No need to check for equality here when we insert so this has one less if branch than + // the similar code path in addWithoutResize. + while (keepGoing) { + if (!bitset.isSet(newPos)) { + bitset.set(newPos); + longArray.set(newPos * 2, keyPointer); + longArray.set(newPos * 2 + 1, hashcode); + keepGoing = false; + } else { + newPos = (newPos + step) & mask; + step++; + } + } + } + + // Deallocate the old data structures. + memoryManager.free(oldLongArray.memoryBlock()); + if (enablePerfMetrics) { + timeSpentResizingNs += System.nanoTime() - resizeStartTime; + } + } + + /** Returns the next number greater or equal num that is power of 2. */ + private static long nextPowerOf2(long num) { + final long highBit = Long.highestOneBit(num); + return (highBit == num) ? num : highBit << 1; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java new file mode 100644 index 0000000000000..7c321baffe82d --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +/** + * Interface that defines how we can grow the size of a hash map when it is over a threshold. + */ +public interface HashMapGrowthStrategy { + + int nextCapacity(int currentCapacity); + + /** + * Double the size of the hash map every time. + */ + HashMapGrowthStrategy DOUBLING = new Doubling(); + + class Doubling implements HashMapGrowthStrategy { + @Override + public int nextCapacity(int currentCapacity) { + return currentCapacity * 2; + } + } + +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java new file mode 100644 index 0000000000000..62c29c8cc1e4d --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +/** + * Manages memory for an executor. Individual operators / tasks allocate memory through + * {@link TaskMemoryManager} objects, which obtain their memory from ExecutorMemoryManager. + */ +public class ExecutorMemoryManager { + + /** + * Allocator, exposed for enabling untracked allocations of temporary data structures. + */ + public final MemoryAllocator allocator; + + /** + * Tracks whether memory will be allocated on the JVM heap or off-heap using sun.misc.Unsafe. + */ + final boolean inHeap; + + /** + * Construct a new ExecutorMemoryManager. + * + * @param allocator the allocator that will be used + */ + public ExecutorMemoryManager(MemoryAllocator allocator) { + this.inHeap = allocator instanceof HeapMemoryAllocator; + this.allocator = allocator; + } + + /** + * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed + * to be zeroed out (call `zero()` on the result if this is necessary). + */ + MemoryBlock allocate(long size) throws OutOfMemoryError { + return allocator.allocate(size); + } + + void free(MemoryBlock memory) { + allocator.free(memory); + } + +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java new file mode 100644 index 0000000000000..bbe83d36cf36b --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +/** + * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array. + */ +public class HeapMemoryAllocator implements MemoryAllocator { + + @Override + public MemoryBlock allocate(long size) throws OutOfMemoryError { + long[] array = new long[(int) (size / 8)]; + return MemoryBlock.fromLongArray(array); + } + + @Override + public void free(MemoryBlock memory) { + // Do nothing + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java new file mode 100644 index 0000000000000..5192f68c862cf --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +public interface MemoryAllocator { + + /** + * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed + * to be zeroed out (call `zero()` on the result if this is necessary). + */ + MemoryBlock allocate(long size) throws OutOfMemoryError; + + void free(MemoryBlock memory); + + MemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); + + MemoryAllocator HEAP = new HeapMemoryAllocator(); +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java new file mode 100644 index 0000000000000..3dc82d8c2eb39 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import javax.annotation.Nullable; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. + */ +public class MemoryBlock extends MemoryLocation { + + private final long length; + + /** + * Optional page number; used when this MemoryBlock represents a page allocated by a + * MemoryManager. This is package-private and is modified by MemoryManager. + */ + int pageNumber = -1; + + MemoryBlock(@Nullable Object obj, long offset, long length) { + super(obj, offset); + this.length = length; + } + + /** + * Returns the size of the memory block. + */ + public long size() { + return length; + } + + /** + * Creates a memory block pointing to the memory used by the long array. + */ + public static MemoryBlock fromLongArray(final long[] array) { + return new MemoryBlock(array, PlatformDependent.LONG_ARRAY_OFFSET, array.length * 8); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java new file mode 100644 index 0000000000000..74ebc87dc978c --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import javax.annotation.Nullable; + +/** + * A memory location. Tracked either by a memory address (with off-heap allocation), + * or by an offset from a JVM object (in-heap allocation). + */ +public class MemoryLocation { + + @Nullable + Object obj; + + long offset; + + public MemoryLocation(@Nullable Object obj, long offset) { + this.obj = obj; + this.offset = offset; + } + + public MemoryLocation() { + this(null, 0); + } + + public void setObjAndOffset(Object newObj, long newOffset) { + this.obj = newObj; + this.offset = newOffset; + } + + public final Object getBaseObject() { + return obj; + } + + public final long getBaseOffset() { + return offset; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java new file mode 100644 index 0000000000000..9224988e6ad69 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import java.util.*; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Manages the memory allocated by an individual task. + *

+ * Most of the complexity in this class deals with encoding of off-heap addresses into 64-bit longs. + * In off-heap mode, memory can be directly addressed with 64-bit longs. In on-heap mode, memory is + * addressed by the combination of a base Object reference and a 64-bit offset within that object. + * This is a problem when we want to store pointers to data structures inside of other structures, + * such as record pointers inside hashmaps or sorting buffers. Even if we decided to use 128 bits + * to address memory, we can't just store the address of the base object since it's not guaranteed + * to remain stable as the heap gets reorganized due to GC. + *

+ * Instead, we use the following approach to encode record pointers in 64-bit longs: for off-heap + * mode, just store the raw address, and for on-heap mode use the upper 13 bits of the address to + * store a "page number" and the lower 51 bits to store an offset within this page. These page + * numbers are used to index into a "page table" array inside of the MemoryManager in order to + * retrieve the base object. + *

+ * This allows us to address 8192 pages. In on-heap mode, the maximum page size is limited by the + * maximum size of a long[] array, allowing us to address 8192 * 2^32 * 8 bytes, which is + * approximately 35 terabytes of memory. + */ +public final class TaskMemoryManager { + + private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class); + + /** + * The number of entries in the page table. + */ + private static final int PAGE_TABLE_SIZE = 1 << 13; + + /** Bit mask for the lower 51 bits of a long. */ + private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; + + /** Bit mask for the upper 13 bits of a long */ + private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS; + + /** + * Similar to an operating system's page table, this array maps page numbers into base object + * pointers, allowing us to translate between the hashtable's internal 64-bit address + * representation and the baseObject+offset representation which we use to support both in- and + * off-heap addresses. When using an off-heap allocator, every entry in this map will be `null`. + * When using an in-heap allocator, the entries in this map will point to pages' base objects. + * Entries are added to this map as new data pages are allocated. + */ + private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE]; + + /** + * Bitmap for tracking free pages. + */ + private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE); + + /** + * Tracks memory allocated with {@link TaskMemoryManager#allocate(long)}, used to detect / clean + * up leaked memory. + */ + private final HashSet allocatedNonPageMemory = new HashSet(); + + private final ExecutorMemoryManager executorMemoryManager; + + /** + * Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods + * without doing any masking or lookups. Since this branching should be well-predicted by the JIT, + * this extra layer of indirection / abstraction hopefully shouldn't be too expensive. + */ + private final boolean inHeap; + + /** + * Construct a new MemoryManager. + */ + public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) { + this.inHeap = executorMemoryManager.inHeap; + this.executorMemoryManager = executorMemoryManager; + } + + /** + * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is + * intended for allocating large blocks of memory that will be shared between operators. + */ + public MemoryBlock allocatePage(long size) { + if (logger.isTraceEnabled()) { + logger.trace("Allocating {} byte page", size); + } + if (size >= (1L << 51)) { + throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes"); + } + + final int pageNumber; + synchronized (this) { + pageNumber = allocatedPages.nextClearBit(0); + if (pageNumber >= PAGE_TABLE_SIZE) { + throw new IllegalStateException( + "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); + } + allocatedPages.set(pageNumber); + } + final MemoryBlock page = executorMemoryManager.allocate(size); + page.pageNumber = pageNumber; + pageTable[pageNumber] = page; + if (logger.isDebugEnabled()) { + logger.debug("Allocate page number {} ({} bytes)", pageNumber, size); + } + return page; + } + + /** + * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}. + */ + public void freePage(MemoryBlock page) { + if (logger.isTraceEnabled()) { + logger.trace("Freeing page number {} ({} bytes)", page.pageNumber, page.size()); + } + assert (page.pageNumber != -1) : + "Called freePage() on memory that wasn't allocated with allocatePage()"; + executorMemoryManager.free(page); + synchronized (this) { + allocatedPages.clear(page.pageNumber); + } + pageTable[page.pageNumber] = null; + if (logger.isDebugEnabled()) { + logger.debug("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + } + } + + /** + * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed + * to be zeroed out (call `zero()` on the result if this is necessary). This method is intended + * to be used for allocating operators' internal data structures. For data pages that you want to + * exchange between operators, consider using {@link TaskMemoryManager#allocatePage(long)}, since + * that will enable intra-memory pointers (see + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} and this class's + * top-level Javadoc for more details). + */ + public MemoryBlock allocate(long size) throws OutOfMemoryError { + final MemoryBlock memory = executorMemoryManager.allocate(size); + allocatedNonPageMemory.add(memory); + return memory; + } + + /** + * Free memory allocated by {@link TaskMemoryManager#allocate(long)}. + */ + public void free(MemoryBlock memory) { + assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()"; + executorMemoryManager.free(memory); + final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory); + assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!"; + } + + /** + * Given a memory page and offset within that page, encode this address into a 64-bit long. + * This address will remain valid as long as the corresponding page has not been freed. + */ + public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { + if (inHeap) { + assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; + return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS); + } else { + return offsetInPage; + } + } + + /** + * Get the page associated with an address encoded by + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} + */ + public Object getPage(long pagePlusOffsetAddress) { + if (inHeap) { + final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + final Object page = pageTable[pageNumber].getBaseObject(); + assert (page != null); + return page; + } else { + return null; + } + } + + /** + * Get the offset associated with an address encoded by + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} + */ + public long getOffsetInPage(long pagePlusOffsetAddress) { + if (inHeap) { + return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); + } else { + return pagePlusOffsetAddress; + } + } + + /** + * Clean up all allocated memory and pages. Returns the number of bytes freed. A non-zero return + * value can be used to detect memory leaks. + */ + public long cleanUpAllAllocatedMemory() { + long freedBytes = 0; + for (MemoryBlock page : pageTable) { + if (page != null) { + freedBytes += page.size(); + freePage(page); + } + } + final Iterator iter = allocatedNonPageMemory.iterator(); + while (iter.hasNext()) { + final MemoryBlock memory = iter.next(); + freedBytes += memory.size(); + // We don't call free() here because that calls Set.remove, which would lead to a + // ConcurrentModificationException here. + executorMemoryManager.free(memory); + iter.remove(); + } + return freedBytes; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java new file mode 100644 index 0000000000000..15898771fef25 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * A simple {@link MemoryAllocator} that uses {@code Unsafe} to allocate off-heap memory. + */ +public class UnsafeMemoryAllocator implements MemoryAllocator { + + @Override + public MemoryBlock allocate(long size) throws OutOfMemoryError { + long address = PlatformDependent.UNSAFE.allocateMemory(size); + return new MemoryBlock(null, address, size); + } + + @Override + public void free(MemoryBlock memory) { + assert (memory.obj == null) : + "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; + PlatformDependent.UNSAFE.freeMemory(memory.offset); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java new file mode 100644 index 0000000000000..5974cf91ff993 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.array; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.unsafe.memory.MemoryBlock; + +public class LongArraySuite { + + @Test + public void basicTest() { + long[] bytes = new long[2]; + LongArray arr = new LongArray(MemoryBlock.fromLongArray(bytes)); + arr.set(0, 1L); + arr.set(1, 2L); + arr.set(1, 3L); + Assert.assertEquals(2, arr.size()); + Assert.assertEquals(1L, arr.get(0)); + Assert.assertEquals(3L, arr.get(1)); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java new file mode 100644 index 0000000000000..e3a824e29b768 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.bitset; + +import junit.framework.Assert; +import org.apache.spark.unsafe.bitset.BitSet; +import org.junit.Test; + +import org.apache.spark.unsafe.memory.MemoryBlock; + +public class BitSetSuite { + + private static BitSet createBitSet(int capacity) { + assert capacity % 64 == 0; + return new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); + } + + @Test + public void basicOps() { + BitSet bs = createBitSet(64); + Assert.assertEquals(64, bs.capacity()); + + // Make sure the bit set starts empty. + for (int i = 0; i < bs.capacity(); i++) { + Assert.assertFalse(bs.isSet(i)); + } + + // Set every bit and check it. + for (int i = 0; i < bs.capacity(); i++) { + bs.set(i); + Assert.assertTrue(bs.isSet(i)); + } + + // Unset every bit and check it. + for (int i = 0; i < bs.capacity(); i++) { + Assert.assertTrue(bs.isSet(i)); + bs.unset(i); + Assert.assertFalse(bs.isSet(i)); + } + } + + @Test + public void traversal() { + BitSet bs = createBitSet(256); + + Assert.assertEquals(-1, bs.nextSetBit(0)); + Assert.assertEquals(-1, bs.nextSetBit(10)); + Assert.assertEquals(-1, bs.nextSetBit(64)); + + bs.set(10); + Assert.assertEquals(10, bs.nextSetBit(0)); + Assert.assertEquals(10, bs.nextSetBit(1)); + Assert.assertEquals(10, bs.nextSetBit(10)); + Assert.assertEquals(-1, bs.nextSetBit(11)); + + bs.set(11); + Assert.assertEquals(10, bs.nextSetBit(10)); + Assert.assertEquals(11, bs.nextSetBit(11)); + + // Skip a whole word and find it + bs.set(190); + Assert.assertEquals(190, bs.nextSetBit(12)); + + Assert.assertEquals(-1, bs.nextSetBit(191)); + Assert.assertEquals(-1, bs.nextSetBit(256)); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java new file mode 100644 index 0000000000000..3b9175835229c --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.hash; + +import java.util.HashSet; +import java.util.Random; +import java.util.Set; + +import junit.framework.Assert; +import org.apache.spark.unsafe.PlatformDependent; +import org.junit.Test; + +/** + * Test file based on Guava's Murmur3Hash32Test. + */ +public class Murmur3_x86_32Suite { + + private static final Murmur3_x86_32 hasher = new Murmur3_x86_32(0); + + @Test + public void testKnownIntegerInputs() { + Assert.assertEquals(593689054, hasher.hashInt(0)); + Assert.assertEquals(-189366624, hasher.hashInt(-42)); + Assert.assertEquals(-1134849565, hasher.hashInt(42)); + Assert.assertEquals(-1718298732, hasher.hashInt(Integer.MIN_VALUE)); + Assert.assertEquals(-1653689534, hasher.hashInt(Integer.MAX_VALUE)); + } + + @Test + public void testKnownLongInputs() { + Assert.assertEquals(1669671676, hasher.hashLong(0L)); + Assert.assertEquals(-846261623, hasher.hashLong(-42L)); + Assert.assertEquals(1871679806, hasher.hashLong(42L)); + Assert.assertEquals(1366273829, hasher.hashLong(Long.MIN_VALUE)); + Assert.assertEquals(-2106506049, hasher.hashLong(Long.MAX_VALUE)); + } + + @Test + public void randomizedStressTest() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. + Set hashcodes = new HashSet(); + for (int i = 0; i < size; i++) { + int vint = rand.nextInt(); + long lint = rand.nextLong(); + Assert.assertEquals(hasher.hashInt(vint), hasher.hashInt(vint)); + Assert.assertEquals(hasher.hashLong(lint), hasher.hashLong(lint)); + + hashcodes.add(hasher.hashLong(lint)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } + + @Test + public void randomizedStressTestBytes() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. + Set hashcodes = new HashSet(); + for (int i = 0; i < size; i++) { + int byteArrSize = rand.nextInt(100) * 8; + byte[] bytes = new byte[byteArrSize]; + rand.nextBytes(bytes); + + Assert.assertEquals( + hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(hasher.hashUnsafeWords( + bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } + + @Test + public void randomizedStressTestPaddedStrings() { + int size = 64000; + // A set used to track collision rate. + Set hashcodes = new HashSet(); + for (int i = 0; i < size; i++) { + int byteArrSize = 8; + byte[] strBytes = ("" + i).getBytes(); + byte[] paddedBytes = new byte[byteArrSize]; + System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + + Assert.assertEquals( + hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(hasher.hashUnsafeWords( + paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java new file mode 100644 index 0000000000000..7a5c0622d1ffb --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +import java.lang.Exception; +import java.nio.ByteBuffer; +import java.util.*; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.PlatformDependent; +import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +public abstract class AbstractBytesToBytesMapSuite { + + private final Random rand = new Random(42); + + private TaskMemoryManager memoryManager; + + @Before + public void setup() { + memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator())); + } + + @After + public void tearDown() { + if (memoryManager != null) { + memoryManager.cleanUpAllAllocatedMemory(); + memoryManager = null; + } + } + + protected abstract MemoryAllocator getMemoryAllocator(); + + private static byte[] getByteArray(MemoryLocation loc, int size) { + final byte[] arr = new byte[size]; + PlatformDependent.copyMemory( + loc.getBaseObject(), + loc.getBaseOffset(), + arr, + BYTE_ARRAY_OFFSET, + size + ); + return arr; + } + + private byte[] getRandomByteArray(int numWords) { + Assert.assertTrue(numWords > 0); + final int lengthInBytes = numWords * 8; + final byte[] bytes = new byte[lengthInBytes]; + rand.nextBytes(bytes); + return bytes; + } + + /** + * Fast equality checking for byte arrays, since these comparisons are a bottleneck + * in our stress tests. + */ + private static boolean arrayEquals( + byte[] expected, + MemoryLocation actualAddr, + long actualLengthBytes) { + return (actualLengthBytes == expected.length) && ByteArrayMethods.wordAlignedArrayEquals( + expected, + BYTE_ARRAY_OFFSET, + actualAddr.getBaseObject(), + actualAddr.getBaseOffset(), + expected.length + ); + } + + @Test + public void emptyMap() { + BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64); + try { + Assert.assertEquals(0, map.size()); + final int keyLengthInWords = 10; + final int keyLengthInBytes = keyLengthInWords * 8; + final byte[] key = getRandomByteArray(keyLengthInWords); + Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); + } finally { + map.free(); + } + } + + @Test + public void setAndRetrieveAKey() { + BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64); + final int recordLengthWords = 10; + final int recordLengthBytes = recordLengthWords * 8; + final byte[] keyData = getRandomByteArray(recordLengthWords); + final byte[] valueData = getRandomByteArray(recordLengthWords); + try { + final BytesToBytesMap.Location loc = + map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes); + Assert.assertFalse(loc.isDefined()); + loc.putNewKey( + keyData, + BYTE_ARRAY_OFFSET, + recordLengthBytes, + valueData, + BYTE_ARRAY_OFFSET, + recordLengthBytes + ); + // After storing the key and value, the other location methods should return results that + // reflect the result of this store without us having to call lookup() again on the same key. + Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); + Assert.assertEquals(recordLengthBytes, loc.getValueLength()); + Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); + Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); + + // After calling lookup() the location should still point to the correct data. + Assert.assertTrue(map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); + Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); + Assert.assertEquals(recordLengthBytes, loc.getValueLength()); + Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); + Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); + + try { + loc.putNewKey( + keyData, + BYTE_ARRAY_OFFSET, + recordLengthBytes, + valueData, + BYTE_ARRAY_OFFSET, + recordLengthBytes + ); + Assert.fail("Should not be able to set a new value for a key"); + } catch (AssertionError e) { + // Expected exception; do nothing. + } + } finally { + map.free(); + } + } + + @Test + public void iteratorTest() throws Exception { + final int size = 128; + BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2); + try { + for (long i = 0; i < size; i++) { + final long[] value = new long[] { i }; + final BytesToBytesMap.Location loc = + map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8); + Assert.assertFalse(loc.isDefined()); + loc.putNewKey( + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8, + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8 + ); + } + final java.util.BitSet valuesSeen = new java.util.BitSet(size); + final Iterator iter = map.iterator(); + while (iter.hasNext()) { + final BytesToBytesMap.Location loc = iter.next(); + Assert.assertTrue(loc.isDefined()); + final MemoryLocation keyAddress = loc.getKeyAddress(); + final MemoryLocation valueAddress = loc.getValueAddress(); + final long key = PlatformDependent.UNSAFE.getLong( + keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + final long value = PlatformDependent.UNSAFE.getLong( + valueAddress.getBaseObject(), valueAddress.getBaseOffset()); + Assert.assertEquals(key, value); + valuesSeen.set((int) value); + } + Assert.assertEquals(size, valuesSeen.cardinality()); + } finally { + map.free(); + } + } + + @Test + public void randomizedStressTest() { + final int size = 65536; + // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays + // into ByteBuffers in order to use them as keys here. + final Map expected = new HashMap(); + final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size); + + try { + // Fill the map to 90% full so that we can trigger probing + for (int i = 0; i < size * 0.9; i++) { + final byte[] key = getRandomByteArray(rand.nextInt(256) + 1); + final byte[] value = getRandomByteArray(rand.nextInt(512) + 1); + if (!expected.containsKey(ByteBuffer.wrap(key))) { + expected.put(ByteBuffer.wrap(key), value); + final BytesToBytesMap.Location loc = map.lookup( + key, + BYTE_ARRAY_OFFSET, + key.length + ); + Assert.assertFalse(loc.isDefined()); + loc.putNewKey( + key, + BYTE_ARRAY_OFFSET, + key.length, + value, + BYTE_ARRAY_OFFSET, + value.length + ); + // After calling putNewKey, the following should be true, even before calling + // lookup(): + Assert.assertTrue(loc.isDefined()); + Assert.assertEquals(key.length, loc.getKeyLength()); + Assert.assertEquals(value.length, loc.getValueLength()); + Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length)); + Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length)); + } + } + + for (Map.Entry entry : expected.entrySet()) { + final byte[] key = entry.getKey().array(); + final byte[] value = entry.getValue(); + final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length); + Assert.assertTrue(loc.isDefined()); + Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); + Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); + } + } finally { + map.free(); + } + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java new file mode 100644 index 0000000000000..5a10de49f54fe --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +import org.apache.spark.unsafe.memory.MemoryAllocator; + +public class BytesToBytesMapOffHeapSuite extends AbstractBytesToBytesMapSuite { + + @Override + protected MemoryAllocator getMemoryAllocator() { + return MemoryAllocator.UNSAFE; + } + +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java new file mode 100644 index 0000000000000..12cc9b25d93b3 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +import org.apache.spark.unsafe.memory.MemoryAllocator; + +public class BytesToBytesMapOnHeapSuite extends AbstractBytesToBytesMapSuite { + + @Override + protected MemoryAllocator getMemoryAllocator() { + return MemoryAllocator.HEAP; + } + +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java new file mode 100644 index 0000000000000..932882f1ca248 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.junit.Assert; +import org.junit.Test; + +public class TaskMemoryManagerSuite { + + @Test + public void leakedNonPageMemoryIsDetected() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + manager.allocate(1024); // leak memory + Assert.assertEquals(1024, manager.cleanUpAllAllocatedMemory()); + } + + @Test + public void leakedPageMemoryIsDetected() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + manager.allocatePage(4096); // leak memory + Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); + } + +} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 741239c953794..4abcf7307a388 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -39,7 +39,7 @@ import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.Master import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.hadoop.security.token.Token +import org.apache.hadoop.security.token.{TokenIdentifier, Token} import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -226,6 +226,7 @@ private[spark] class Client( val distributedUris = new HashSet[String] obtainTokensForNamenodes(nns, hadoopConf, credentials) obtainTokenForHiveMetastore(hadoopConf, credentials) + obtainTokenForHBase(hadoopConf, credentials) val replication = sparkConf.getInt("spark.yarn.submit.file.replication", fs.getDefaultReplication(dst)).toShort @@ -1084,6 +1085,41 @@ object Client extends Logging { } } + /** + * Obtain security token for HBase. + */ + def obtainTokenForHBase(conf: Configuration, credentials: Credentials): Unit = { + if (UserGroupInformation.isSecurityEnabled) { + val mirror = universe.runtimeMirror(getClass.getClassLoader) + + try { + val confCreate = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). + getMethod("create", classOf[Configuration]) + val obtainToken = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). + getMethod("obtainToken", classOf[Configuration]) + + logDebug("Attempting to fetch HBase security token.") + + val hbaseConf = confCreate.invoke(null, conf) + val token = obtainToken.invoke(null, hbaseConf).asInstanceOf[Token[TokenIdentifier]] + credentials.addToken(token.getService, token) + + logInfo("Added HBase security token to credentials.") + } catch { + case e:java.lang.NoSuchMethodException => + logInfo("HBase Method not found: " + e) + case e:java.lang.ClassNotFoundException => + logDebug("HBase Class not found: " + e) + case e:java.lang.NoClassDefFoundError => + logDebug("HBase Class not found: " + e) + case e:Exception => + logError("Exception when obtaining HBase security token: " + e) + } + } + } + /** * Return whether the two file systems are the same. */