diff --git a/.gitignore b/.gitignore index 061c8946d23c1..5b56a67c883e6 100644 --- a/.gitignore +++ b/.gitignore @@ -57,3 +57,4 @@ metastore_db/ metastore/ warehouse/ TempStatsStore/ +sql/hive-thriftserver/test_warehouses diff --git a/assembly/pom.xml b/assembly/pom.xml index 4f6aade133db7..703f15925bc44 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -39,6 +39,7 @@ spark /usr/share/spark root + 744 @@ -164,6 +165,16 @@ + + hive-thriftserver + + + org.apache.spark + spark-hive-thriftserver_${scala.binary.version} + ${project.version} + + + spark-ganglia-lgpl @@ -276,7 +287,7 @@ ${deb.user} ${deb.user} ${deb.install.path}/bin - 744 + ${deb.bin.filemode} diff --git a/bagel/pom.xml b/bagel/pom.xml index 90c4b095bb611..bd51b112e26fa 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-bagel_2.10 - bagel + bagel jar Spark Project Bagel diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala index 70a99b33d753c..ef0bb2ac13f08 100644 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala @@ -72,6 +72,7 @@ object Bagel extends Logging { var verts = vertices var msgs = messages var noActivity = false + var lastRDD: RDD[(K, (V, Array[M]))] = null do { logInfo("Starting superstep " + superstep + ".") val startTime = System.currentTimeMillis @@ -83,6 +84,10 @@ object Bagel extends Logging { val superstep_ = superstep // Create a read-only copy of superstep for capture in closure val (processed, numMsgs, numActiveVerts) = comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep_), storageLevel) + if (lastRDD != null) { + lastRDD.unpersist(false) + } + lastRDD = processed val timeTaken = System.currentTimeMillis - startTime logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) diff --git a/bin/beeline b/bin/beeline new file mode 100755 index 0000000000000..09fe366c609fa --- /dev/null +++ b/bin/beeline @@ -0,0 +1,45 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" + +# Find the java binary +if [ -n "${JAVA_HOME}" ]; then + RUNNER="${JAVA_HOME}/bin/java" +else + if [ `command -v java` ]; then + RUNNER="java" + else + echo "JAVA_HOME is not set" >&2 + exit 1 + fi +fi + +# Compute classpath using external script +classpath_output=$($FWDIR/bin/compute-classpath.sh) +if [[ "$?" != "0" ]]; then + echo "$classpath_output" + exit 1 +else + CLASSPATH=$classpath_output +fi + +CLASS="org.apache.hive.beeline.BeeLine" +exec "$RUNNER" -cp "$CLASSPATH" $CLASS "$@" diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index e81e8c060cb98..16b794a1592e8 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -52,6 +52,7 @@ if [ -n "$SPARK_PREPEND_CLASSES" ]; then CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SCALA_VERSION/classes" fi diff --git a/bin/spark-shell b/bin/spark-shell index 850e9507ec38f..756c8179d12b6 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -46,11 +46,11 @@ function main(){ # (see https://github.com/sbt/sbt/issues/562). stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - $FWDIR/bin/spark-submit spark-shell "$@" --class org.apache.spark.repl.Main + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - $FWDIR/bin/spark-submit spark-shell "$@" --class org.apache.spark.repl.Main + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" fi } diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index 4b9708a8c03f3..b56d69801171c 100755 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -19,4 +19,4 @@ rem set SPARK_HOME=%~dp0.. -cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd spark-shell %* --class org.apache.spark.repl.Main +cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd spark-shell --class org.apache.spark.repl.Main %* diff --git a/bin/spark-sql b/bin/spark-sql new file mode 100755 index 0000000000000..bba7f897b19bc --- /dev/null +++ b/bin/spark-sql @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Shell script for starting the Spark SQL CLI + +# Enter posix mode for bash +set -o posix + +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/spark-sql [options]" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit 0 +fi + +CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" +exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@ diff --git a/core/pom.xml b/core/pom.xml index 1054cec4d77bb..a24743495b0e1 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-core_2.10 - core + core jar Spark Project Core diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8052499ab7526..3e6addeaf04a8 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1037,7 +1037,7 @@ class SparkContext(config: SparkConf) extends Logging { */ private[spark] def getCallSite(): CallSite = { Option(getLocalProperty("externalCallSite")) match { - case Some(callSite) => CallSite(callSite, long = "") + case Some(callSite) => CallSite(callSite, longForm = "") case None => Utils.getCallSite } } @@ -1059,11 +1059,12 @@ class SparkContext(config: SparkConf) extends Logging { } val callSite = getCallSite val cleanedFunc = clean(func) - logInfo("Starting job: " + callSite.short) + logInfo("Starting job: " + callSite.shortForm) val start = System.nanoTime dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, resultHandler, localProperties.get) - logInfo("Job finished: " + callSite.short + ", took " + (System.nanoTime - start) / 1e9 + " s") + logInfo( + "Job finished: " + callSite.shortForm + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() } @@ -1144,11 +1145,12 @@ class SparkContext(config: SparkConf) extends Logging { evaluator: ApproximateEvaluator[U, R], timeout: Long): PartialResult[R] = { val callSite = getCallSite - logInfo("Starting job: " + callSite.short) + logInfo("Starting job: " + callSite.shortForm) val start = System.nanoTime val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.get) - logInfo("Job finished: " + callSite.short + ", took " + (System.nanoTime - start) / 1e9 + " s") + logInfo( + "Job finished: " + callSite.shortForm + ", took " + (System.nanoTime - start) / 1e9 + " s") result } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 462e09466bfa6..d6b0988641a97 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -57,7 +57,10 @@ private[spark] class PythonRDD[T: ClassTag]( override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get - val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) + val localdir = env.blockManager.diskBlockManager.localDirs.map( + f => f.getPath()).mkString(",") + val worker: Socket = env.createPythonWorker(pythonExec, + envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir)) // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 3d8373d8175ee..c9cec33ebaa66 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -46,6 +46,10 @@ object SparkSubmit { private val CLUSTER = 2 private val ALL_DEPLOY_MODES = CLIENT | CLUSTER + // A special jar name that indicates the class being run is inside of Spark itself, and therefore + // no user jar is needed. + private val SPARK_INTERNAL = "spark-internal" + // Special primary resource names that represent shells rather than application jars. private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" @@ -257,7 +261,9 @@ object SparkSubmit { // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (clusterManager == YARN && deployMode == CLUSTER) { childMainClass = "org.apache.spark.deploy.yarn.Client" - childArgs += ("--jar", args.primaryResource) + if (args.primaryResource != SPARK_INTERNAL) { + childArgs += ("--jar", args.primaryResource) + } childArgs += ("--class", args.mainClass) if (args.childArgs != null) { args.childArgs.foreach { arg => childArgs += ("--arg", arg) } @@ -269,6 +275,9 @@ object SparkSubmit { sysProps.getOrElseUpdate(k, v) } + // Spark properties included on command line take precedence + sysProps ++= args.sparkProperties + (childArgs, childClasspath, sysProps, childMainClass) } @@ -329,7 +338,7 @@ object SparkSubmit { * Return whether the given primary resource represents a user jar. */ private def isUserJar(primaryResource: String): Boolean = { - !isShell(primaryResource) && !isPython(primaryResource) + !isShell(primaryResource) && !isPython(primaryResource) && !isInternal(primaryResource) } /** @@ -346,6 +355,10 @@ object SparkSubmit { primaryResource.endsWith(".py") || primaryResource == PYSPARK_SHELL } + private[spark] def isInternal(primaryResource: String): Boolean = { + primaryResource == SPARK_INTERNAL + } + /** * Merge a sequence of comma-separated file lists, some of which may be null to indicate * no files, into a single comma-separated string. diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 57655aa4c32b1..01d0ae541a66b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -55,6 +55,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null + val sparkProperties: HashMap[String, String] = new HashMap[String, String]() parseOpts(args.toList) loadDefaults() @@ -177,6 +178,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { | executorCores $executorCores | totalExecutorCores $totalExecutorCores | propertiesFile $propertiesFile + | extraSparkProperties $sparkProperties | driverMemory $driverMemory | driverCores $driverCores | driverExtraClassPath $driverExtraClassPath @@ -202,8 +204,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { /** Fill in values by parsing user options. */ private def parseOpts(opts: Seq[String]): Unit = { - // Delineates parsing of Spark options from parsing of user options. var inSparkOpts = true + + // Delineates parsing of Spark options from parsing of user options. parse(opts) def parse(opts: Seq[String]): Unit = opts match { @@ -290,6 +293,13 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { jars = Utils.resolveURIs(value) parse(tail) + case ("--conf" | "-c") :: value :: tail => + value.split("=", 2).toSeq match { + case Seq(k, v) => sparkProperties(k) = v + case _ => SparkSubmit.printErrorAndExit(s"Spark config without '=': $value") + } + parse(tail) + case ("--help" | "-h") :: tail => printUsageAndExit(0) @@ -309,7 +319,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { SparkSubmit.printErrorAndExit(errMessage) case v => primaryResource = - if (!SparkSubmit.isShell(v)) { + if (!SparkSubmit.isShell(v) && !SparkSubmit.isInternal(v)) { Utils.resolveURI(v).toString } else { v @@ -349,6 +359,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { | on the PYTHONPATH for Python apps. | --files FILES Comma-separated list of files to be placed in the working | directory of each executor. + | + | --conf PROP=VALUE Arbitrary Spark configuration property. | --properties-file FILE Path to a file from which to load extra properties. If not | specified, this will look for conf/spark-defaults.conf. | diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a8c9ac072449f..01e7065c17b69 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -169,7 +169,8 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis val ui: SparkUI = if (renderUI) { val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) - new SparkUI(conf, appSecManager, replayBus, appId, "/history/" + appId) + new SparkUI(conf, appSecManager, replayBus, appId, + HistoryServer.UI_PATH_PREFIX + s"/$appId") // Do not call ui.bind() to avoid creating a new server for each application } else { null diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index a958c837c2ff6..d7a3e3f120e67 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -75,7 +75,7 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { "Last Updated") private def appRow(info: ApplicationHistoryInfo): Seq[Node] = { - val uiAddress = "/history/" + info.id + val uiAddress = HistoryServer.UI_PATH_PREFIX + s"/${info.id}" val startTime = UIUtils.formatDate(info.startTime) val endTime = UIUtils.formatDate(info.endTime) val duration = UIUtils.formatDuration(info.endTime - info.startTime) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 56b38ddfc9313..cacb9da8c947b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -114,7 +114,7 @@ class HistoryServer( attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) val contextHandler = new ServletContextHandler - contextHandler.setContextPath("/history") + contextHandler.setContextPath(HistoryServer.UI_PATH_PREFIX) contextHandler.addServlet(new ServletHolder(loaderServlet), "/*") attachHandler(contextHandler) } @@ -172,6 +172,8 @@ class HistoryServer( object HistoryServer extends Logging { private val conf = new SparkConf + val UI_PATH_PREFIX = "/history" + def main(argStrings: Array[String]) { SignalLogger.register(log) initSecurity() diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index bb1fcc8190fe4..21f8667819c44 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -35,6 +35,7 @@ import akka.serialization.SerializationExtension import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI @@ -664,9 +665,10 @@ private[spark] class Master( */ def rebuildSparkUI(app: ApplicationInfo): Boolean = { val appName = app.desc.name + val notFoundBasePath = HistoryServer.UI_PATH_PREFIX + "/not-found" val eventLogDir = app.desc.eventLogDir.getOrElse { // Event logging is not enabled for this application - app.desc.appUiUrl = "/history/not-found" + app.desc.appUiUrl = notFoundBasePath return false } val fileSystem = Utils.getHadoopFileSystem(eventLogDir) @@ -681,13 +683,14 @@ private[spark] class Master( logWarning(msg) msg += " Did you specify the correct logging directory?" msg = URLEncoder.encode(msg, "UTF-8") - app.desc.appUiUrl = s"/history/not-found?msg=$msg&title=$title" + app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&title=$title" return false } try { val replayBus = new ReplayListenerBus(eventLogPaths, fileSystem, compressionCodec) - val ui = new SparkUI(new SparkConf, replayBus, appName + " (completed)", "/history/" + app.id) + val ui = new SparkUI(new SparkConf, replayBus, appName + " (completed)", + HistoryServer.UI_PATH_PREFIX + s"/${app.id}") replayBus.replay() appIdToUI(app.id) = ui webUi.attachSparkUI(ui) @@ -702,7 +705,7 @@ private[spark] class Master( var msg = s"Exception in replaying log for application $appName!" logError(msg, e) msg = URLEncoder.encode(msg, "UTF-8") - app.desc.appUiUrl = s"/history/not-found?msg=$msg&exception=$exception&title=$title" + app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&exception=$exception&title=$title" false } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index aca235a62a6a8..7d96089e52ab9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -25,7 +25,7 @@ import scala.language.existentials import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap} +import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleHandle @@ -66,14 +66,14 @@ private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep] */ @DeveloperApi class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner) - extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { + extends RDD[(K, Array[Iterable[_]])](rdds.head.context, Nil) { // For example, `(k, a) cogroup (k, b)` produces k -> Seq(ArrayBuffer as, ArrayBuffer bs). // Each ArrayBuffer is represented as a CoGroup, and the resulting Seq as a CoGroupCombiner. // CoGroupValue is the intermediate state of each value before being merged in compute. - private type CoGroup = ArrayBuffer[Any] + private type CoGroup = CompactBuffer[Any] private type CoGroupValue = (Any, Int) // Int is dependency number - private type CoGroupCombiner = Seq[CoGroup] + private type CoGroupCombiner = Array[CoGroup] private var serializer: Option[Serializer] = None @@ -114,7 +114,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: override val partitioner: Some[Partitioner] = Some(part) - override def compute(s: Partition, context: TaskContext): Iterator[(K, CoGroupCombiner)] = { + override def compute(s: Partition, context: TaskContext): Iterator[(K, Array[Iterable[_]])] = { val sparkConf = SparkEnv.get.conf val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] @@ -150,7 +150,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: getCombiner(kv._1)(depNum) += kv._2 } } - new InterruptibleIterator(context, map.iterator) + new InterruptibleIterator(context, + map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } else { val map = createExternalMap(numRdds) rddIterators.foreach { case (it, depNum) => @@ -161,7 +162,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: } context.taskMetrics.memoryBytesSpilled = map.memoryBytesSpilled context.taskMetrics.diskBytesSpilled = map.diskBytesSpilled - new InterruptibleIterator(context, map.iterator) + new InterruptibleIterator(context, + map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index a6b920467283e..c04d162a39616 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -46,6 +46,7 @@ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer +import org.apache.spark.util.collection.CompactBuffer /** * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. @@ -361,12 +362,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // groupByKey shouldn't use map side combine because map side combine does not // reduce the amount of data shuffled and requires all map side data be inserted // into a hash table, leading to more objects in the old gen. - val createCombiner = (v: V) => ArrayBuffer(v) - val mergeValue = (buf: ArrayBuffer[V], v: V) => buf += v - val mergeCombiners = (c1: ArrayBuffer[V], c2: ArrayBuffer[V]) => c1 ++ c2 - val bufs = combineByKey[ArrayBuffer[V]]( + val createCombiner = (v: V) => CompactBuffer(v) + val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v + val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2 + val bufs = combineByKey[CompactBuffer[V]]( createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine=false) - bufs.mapValues(_.toIterable) + bufs.asInstanceOf[RDD[(K, Iterable[V])]] } /** @@ -571,11 +572,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) throw new SparkException("Default partitioner cannot partition array keys.") } val cg = new CoGroupedRDD[K](Seq(self, other1, other2, other3), partitioner) - cg.mapValues { case Seq(vs, w1s, w2s, w3s) => - (vs.asInstanceOf[Seq[V]], - w1s.asInstanceOf[Seq[W1]], - w2s.asInstanceOf[Seq[W2]], - w3s.asInstanceOf[Seq[W3]]) + cg.mapValues { case Array(vs, w1s, w2s, w3s) => + (vs.asInstanceOf[Iterable[V]], + w1s.asInstanceOf[Iterable[W1]], + w2s.asInstanceOf[Iterable[W2]], + w3s.asInstanceOf[Iterable[W3]]) } } @@ -589,8 +590,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) throw new SparkException("Default partitioner cannot partition array keys.") } val cg = new CoGroupedRDD[K](Seq(self, other), partitioner) - cg.mapValues { case Seq(vs, w1s) => - (vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W]]) + cg.mapValues { case Array(vs, w1s) => + (vs.asInstanceOf[Iterable[V]], w1s.asInstanceOf[Iterable[W]]) } } @@ -604,10 +605,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) throw new SparkException("Default partitioner cannot partition array keys.") } val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner) - cg.mapValues { case Seq(vs, w1s, w2s) => - (vs.asInstanceOf[Seq[V]], - w1s.asInstanceOf[Seq[W1]], - w2s.asInstanceOf[Seq[W2]]) + cg.mapValues { case Array(vs, w1s, w2s) => + (vs.asInstanceOf[Iterable[V]], + w1s.asInstanceOf[Iterable[W1]], + w2s.asInstanceOf[Iterable[W2]]) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index b5b8a5706deb3..a637d6f15b7e5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -39,6 +39,7 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) * * @param prev RDD to be sampled * @param sampler a random sampler + * @param preservesPartitioning whether the sampler preserves the partitioner of the parent RDD * @param seed random seed * @tparam T input RDD item type * @tparam U sampled RDD item type @@ -46,9 +47,12 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag]( prev: RDD[T], sampler: RandomSampler[T, U], + @transient preservesPartitioning: Boolean, @transient seed: Long = Utils.random.nextLong) extends RDD[U](prev) { + @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None + override def getPartitions: Array[Partition] = { val random = new Random(seed) firstParent[T].partitions.map(x => new PartitionwiseSampledRDDPartition(x, random.nextLong())) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a1f2827248891..a6abc49c5359e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -354,11 +354,11 @@ abstract class RDD[T: ClassTag]( def sample(withReplacement: Boolean, fraction: Double, seed: Long = Utils.random.nextLong): RDD[T] = { - require(fraction >= 0.0, "Invalid fraction value: " + fraction) + require(fraction >= 0.0, "Negative fraction value: " + fraction) if (withReplacement) { - new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed) + new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed) } else { - new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), seed) + new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed) } } @@ -374,7 +374,7 @@ abstract class RDD[T: ClassTag]( val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), seed) + new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), true, seed) }.toArray } @@ -586,6 +586,9 @@ abstract class RDD[T: ClassTag]( /** * Return a new RDD by applying a function to each partition of this RDD. + * + * `preservesPartitioning` indicates whether the input function preserves the partitioner, which + * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. */ def mapPartitions[U: ClassTag]( f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { @@ -596,6 +599,9 @@ abstract class RDD[T: ClassTag]( /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. + * + * `preservesPartitioning` indicates whether the input function preserves the partitioner, which + * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. */ def mapPartitionsWithIndex[U: ClassTag]( f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { @@ -607,6 +613,9 @@ abstract class RDD[T: ClassTag]( * :: DeveloperApi :: * Return a new RDD by applying a function to each partition of this RDD. This is a variant of * mapPartitions that also passes the TaskContext into the closure. + * + * `preservesPartitioning` indicates whether the input function preserves the partitioner, which + * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. */ @DeveloperApi def mapPartitionsWithContext[U: ClassTag]( @@ -689,7 +698,7 @@ abstract class RDD[T: ClassTag]( * a map on the other). */ def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = { - zipPartitions(other, true) { (thisIter, otherIter) => + zipPartitions(other, preservesPartitioning = false) { (thisIter, otherIter) => new Iterator[(T, U)] { def hasNext = (thisIter.hasNext, otherIter.hasNext) match { case (true, true) => true @@ -745,14 +754,16 @@ abstract class RDD[T: ClassTag]( * Applies a function f to all elements of this RDD. */ def foreach(f: T => Unit) { - sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f)) + val cleanF = sc.clean(f) + sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF)) } /** * Applies a function f to each partition of this RDD. */ def foreachPartition(f: Iterator[T] => Unit) { - sc.runJob(this, (iter: Iterator[T]) => f(iter)) + val cleanF = sc.clean(f) + sc.runJob(this, (iter: Iterator[T]) => cleanF(iter)) } /** @@ -1214,7 +1225,7 @@ abstract class RDD[T: ClassTag]( /** User code that created this RDD (e.g. `textFile`, `parallelize`). */ @transient private[spark] val creationSite = Utils.getCallSite - private[spark] def getCreationSite: String = Option(creationSite).map(_.short).getOrElse("") + private[spark] def getCreationSite: String = Option(creationSite).map(_.shortForm).getOrElse("") private[spark] def elementClassTag: ClassTag[T] = classTag[T] diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ede3c7d9f01ae..acb4c4946eded 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -455,7 +455,7 @@ class DAGScheduler( waiter.awaitResult() match { case JobSucceeded => {} case JobFailed(exception: Exception) => - logInfo("Failed to run " + callSite.short) + logInfo("Failed to run " + callSite.shortForm) throw exception } } @@ -679,7 +679,7 @@ class DAGScheduler( val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() logInfo("Got job %s (%s) with %d output partitions (allowLocal=%s)".format( - job.jobId, callSite.short, partitions.length, allowLocal)) + job.jobId, callSite.shortForm, partitions.length, allowLocal)) logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") logInfo("Parents of final stage: " + finalStage.parents) logInfo("Missing parents: " + getMissingParentStages(finalStage)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index a90b0d475c04e..ae6ca9f4e7bf5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -63,6 +63,13 @@ private[spark] class EventLoggingListener( // For testing. Keep track of all JSON serialized events that have been logged. private[scheduler] val loggedEvents = new ArrayBuffer[JValue] + /** + * Return only the unique application directory without the base directory. + */ + def getApplicationLogDir(): String = { + name + } + /** * Begin logging events. * If compression is used, log a file that indicates which compression library is used. diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 8ec482a6f6d9c..798cbc598d36e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -108,8 +108,8 @@ private[spark] class Stage( def attemptId: Int = nextAttemptId - val name = callSite.short - val details = callSite.long + val name = callSite.shortForm + val details = callSite.longForm override def toString = "Stage " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index e9f6273bfd9f0..5b897597fa285 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -57,7 +57,7 @@ private[spark] class LocalActor( case StatusUpdate(taskId, state, serializedData) => scheduler.statusUpdate(taskId, state, serializedData) if (TaskState.isFinished(state)) { - freeCores += 1 + freeCores += scheduler.CPUS_PER_TASK reviveOffers() } @@ -68,7 +68,7 @@ private[spark] class LocalActor( def reviveOffers() { val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) for (task <- scheduler.resourceOffers(offers).flatten) { - freeCores -= 1 + freeCores -= scheduler.CPUS_PER_TASK executor.launchTask(executorBackend, task.taskId, task.name, task.serializedTask) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 1ce4243194798..fa79b25759153 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -31,6 +31,7 @@ import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage._ import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock} import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.collection.CompactBuffer import scala.reflect.ClassTag @@ -48,6 +49,7 @@ class KryoSerializer(conf: SparkConf) private val bufferSize = conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024 private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) + private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) private val registrator = conf.getOption("spark.kryo.registrator") def newKryoOutput() = new KryoOutput(bufferSize) @@ -55,6 +57,7 @@ class KryoSerializer(conf: SparkConf) def newKryo(): Kryo = { val instantiator = new EmptyScalaKryoInstantiator val kryo = instantiator.newKryo() + kryo.setRegistrationRequired(registrationRequired) val classLoader = Thread.currentThread.getContextClassLoader // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops. @@ -183,9 +186,11 @@ private[serializer] object KryoSerializer { classOf[GotBlock], classOf[GetBlock], classOf[MapStatus], + classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Byte]], - classOf[BoundedPriorityQueue[_]] + classOf[BoundedPriorityQueue[_]], + classOf[SparkConf] ) } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 673fc19c060a4..2e7ed7538e6e5 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -43,7 +43,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD /* Create one local directory for each path mentioned in spark.local.dir; then, inside this * directory, create multiple subdirectories that we will hash files into, in order to avoid * having really large inodes at the top level. */ - private val localDirs: Array[File] = createLocalDirs() + val localDirs: Array[File] = createLocalDirs() if (localDirs.isEmpty) { logError("Failed to create any local dir.") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index e07aa2ee3a5a2..715cc2f4df8dd 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -149,7 +149,7 @@ private[spark] object UIUtils extends Logging { def prependBaseUri(basePath: String = "", resource: String = "") = uiRoot + basePath + resource - val commonHeaderNodes = { + def commonHeaderNodes = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 5f45c0ced5ec5..f8b308c981548 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -18,6 +18,7 @@ package org.apache.spark.ui.jobs import scala.xml.Node +import scala.xml.Text import java.util.Date @@ -99,19 +100,30 @@ private[ui] class StageTableBase( {s.name} + val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) val details = if (s.details.nonEmpty) { - +show details - - + +details + ++ + } val stageDataOption = listener.stageIdToData.get(s.stageId) // Too many nested map/flatMaps with options are just annoying to read. Do this imperatively. if (stageDataOption.isDefined && stageDataOption.get.description.isDefined) { val desc = stageDataOption.get.description -
{desc}
{nameLink} {killLink}
+
{desc}
{killLink} {nameLink} {details}
} else {
{killLink} {nameLink} {details}
} diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 3448aaaf5724c..bb6079154aafe 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -257,7 +257,8 @@ private[spark] object JsonProtocol { val reason = Utils.getFormattedClassName(taskEndReason) val json = taskEndReason match { case fetchFailed: FetchFailed => - val blockManagerAddress = blockManagerIdToJson(fetchFailed.bmAddress) + val blockManagerAddress = Option(fetchFailed.bmAddress). + map(blockManagerIdToJson).getOrElse(JNothing) ("Block Manager Address" -> blockManagerAddress) ~ ("Shuffle ID" -> fetchFailed.shuffleId) ~ ("Map ID" -> fetchFailed.mapId) ~ diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5784e974fbb67..1a4f4eba98ea8 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -44,7 +44,7 @@ import org.apache.spark.executor.ExecutorUncaughtExceptionHandler import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} /** CallSite represents a place in user code. It can have a short and a long form. */ -private[spark] case class CallSite(short: String, long: String) +private[spark] case class CallSite(shortForm: String, longForm: String) /** * Various utility methods used by Spark. @@ -848,8 +848,8 @@ private[spark] object Utils extends Logging { } val callStackDepth = System.getProperty("spark.callstack.depth", "20").toInt CallSite( - short = "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine), - long = callStack.take(callStackDepth).mkString("\n")) + shortForm = "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine), + longForm = callStack.take(callStackDepth).mkString("\n")) } /** Return a string containing part of a file from byte 'start' to 'end'. */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala new file mode 100644 index 0000000000000..d44e15e3c97ea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/CompactBuffer.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +/** + * An append-only buffer similar to ArrayBuffer, but more memory-efficient for small buffers. + * ArrayBuffer always allocates an Object array to store the data, with 16 entries by default, + * so it has about 80-100 bytes of overhead. In contrast, CompactBuffer can keep up to two + * elements in fields of the main object, and only allocates an Array[AnyRef] if there are more + * entries than that. This makes it more efficient for operations like groupBy where we expect + * some keys to have very few elements. + */ +private[spark] class CompactBuffer[T] extends Seq[T] with Serializable { + // First two elements + private var element0: T = _ + private var element1: T = _ + + // Number of elements, including our two in the main object + private var curSize = 0 + + // Array for extra elements + private var otherElements: Array[AnyRef] = null + + def apply(position: Int): T = { + if (position < 0 || position >= curSize) { + throw new IndexOutOfBoundsException + } + if (position == 0) { + element0 + } else if (position == 1) { + element1 + } else { + otherElements(position - 2).asInstanceOf[T] + } + } + + private def update(position: Int, value: T): Unit = { + if (position < 0 || position >= curSize) { + throw new IndexOutOfBoundsException + } + if (position == 0) { + element0 = value + } else if (position == 1) { + element1 = value + } else { + otherElements(position - 2) = value.asInstanceOf[AnyRef] + } + } + + def += (value: T): CompactBuffer[T] = { + val newIndex = curSize + if (newIndex == 0) { + element0 = value + curSize = 1 + } else if (newIndex == 1) { + element1 = value + curSize = 2 + } else { + growToSize(curSize + 1) + otherElements(newIndex - 2) = value.asInstanceOf[AnyRef] + } + this + } + + def ++= (values: TraversableOnce[T]): CompactBuffer[T] = { + values match { + // Optimize merging of CompactBuffers, used in cogroup and groupByKey + case compactBuf: CompactBuffer[T] => + val oldSize = curSize + // Copy the other buffer's size and elements to local variables in case it is equal to us + val itsSize = compactBuf.curSize + val itsElements = compactBuf.otherElements + growToSize(curSize + itsSize) + if (itsSize == 1) { + this(oldSize) = compactBuf.element0 + } else if (itsSize == 2) { + this(oldSize) = compactBuf.element0 + this(oldSize + 1) = compactBuf.element1 + } else if (itsSize > 2) { + this(oldSize) = compactBuf.element0 + this(oldSize + 1) = compactBuf.element1 + // At this point our size is also above 2, so just copy its array directly into ours. + // Note that since we added two elements above, the index in this.otherElements that we + // should copy to is oldSize. + System.arraycopy(itsElements, 0, otherElements, oldSize, itsSize - 2) + } + + case _ => + values.foreach(e => this += e) + } + this + } + + override def length: Int = curSize + + override def size: Int = curSize + + override def iterator: Iterator[T] = new Iterator[T] { + private var pos = 0 + override def hasNext: Boolean = pos < curSize + override def next(): T = { + if (!hasNext) { + throw new NoSuchElementException + } + pos += 1 + apply(pos - 1) + } + } + + /** Increase our size to newSize and grow the backing array if needed. */ + private def growToSize(newSize: Int): Unit = { + if (newSize < 0) { + throw new UnsupportedOperationException("Can't grow buffer past Int.MaxValue elements") + } + val capacity = if (otherElements != null) otherElements.length + 2 else 2 + if (newSize > capacity) { + var newArrayLen = 8 + while (newSize - 2 > newArrayLen) { + newArrayLen *= 2 + if (newArrayLen == Int.MinValue) { + // Prevent overflow if we double from 2^30 to 2^31, which will become Int.MinValue. + // Note that we set the new array length to Int.MaxValue - 2 so that our capacity + // calculation above still gives a positive integer. + newArrayLen = Int.MaxValue - 2 + } + } + val newArray = new Array[AnyRef](newArrayLen) + if (otherElements != null) { + System.arraycopy(otherElements, 0, newArray, 0, otherElements.length) + } + otherElements = newArray + } + curSize = newSize + } +} + +private[spark] object CompactBuffer { + def apply[T](): CompactBuffer[T] = new CompactBuffer[T] + + def apply[T](value: T): CompactBuffer[T] = { + val buf = new CompactBuffer[T] + buf += value + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 71ab2a3e3bef4..be8f6529f7a1c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -106,6 +106,7 @@ class ExternalAppendOnlyMap[K, V, C]( private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() + private val threadId = Thread.currentThread().getId /** * Insert the given key and value into the map. @@ -128,7 +129,6 @@ class ExternalAppendOnlyMap[K, V, C]( // Atomically check whether there is sufficient memory in the global pool for // this map to grow and, if possible, allocate the required amount shuffleMemoryMap.synchronized { - val threadId = Thread.currentThread().getId val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId) val availableMemory = maxMemoryThreshold - (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L)) @@ -153,8 +153,8 @@ class ExternalAppendOnlyMap[K, V, C]( */ private def spill(mapSize: Long) { spillCount += 1 - logWarning("Spilling in-memory map of %d MB to disk (%d time%s so far)" - .format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) + logWarning("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" + .format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) val (blockId, file) = diskBlockManager.createTempBlock() var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize) var objectsWritten = 0 diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index fc00458083a33..d1cb2d9d3a53b 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -156,15 +156,20 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { test("CoGroupedRDD") { val longLineageRDD1 = generateFatPairRDD() + + // Collect the RDD as sequences instead of arrays to enable equality tests in testRDD + val seqCollectFunc = (rdd: RDD[(Int, Array[Iterable[Int]])]) => + rdd.map{case (p, a) => (p, a.toSeq)}.collect(): Any + testRDD(rdd => { CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner) - }) + }, seqCollectFunc) val longLineageRDD2 = generateFatPairRDD() testRDDPartitions(rdd => { CheckpointSuite.cogroup( longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner) - }) + }, seqCollectFunc) } test("ZippedPartitionsRDD") { @@ -235,12 +240,19 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { assert(rdd.partitions.size === 0) } + def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() + /** * Test checkpointing of the RDD generated by the given operation. It tests whether the * serialized size of the RDD is reduce after checkpointing or not. This function should be called * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.). + * + * @param op an operation to run on the RDD + * @param collectFunc a function for collecting the values in the RDD, in case there are + * non-comparable types like arrays that we want to convert to something that supports == */ - def testRDD[U: ClassTag](op: (RDD[Int]) => RDD[U]) { + def testRDD[U: ClassTag](op: (RDD[Int]) => RDD[U], + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _) { // Generate the final RDD using given RDD operation val baseRDD = generateFatRDD() val operatedRDD = op(baseRDD) @@ -258,13 +270,13 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) operatedRDD.checkpoint() - val result = operatedRDD.collect() + val result = collectFunc(operatedRDD) operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) // Test whether the checkpoint file has been created - assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result) + assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) // Test whether dependencies have been changed from its earlier parent RDD assert(operatedRDD.dependencies.head.rdd != parentRDD) @@ -279,7 +291,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { assert(operatedRDD.partitions.length === numPartitions) // Test whether the data in the checkpointed RDD is same as original - assert(operatedRDD.collect() === result) + assert(collectFunc(operatedRDD) === result) // Test whether serialized size of the RDD has reduced. logInfo("Size of " + rddType + @@ -289,7 +301,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { "Size of " + rddType + " did not reduce after checkpointing " + " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" ) - } /** @@ -300,8 +311,12 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { * This function should be called only those RDD whose partitions refer to parent RDD's * partitions (i.e., do not call it on simple RDD like MappedRDD). * + * @param op an operation to run on the RDD + * @param collectFunc a function for collecting the values in the RDD, in case there are + * non-comparable types like arrays that we want to convert to something that supports == */ - def testRDDPartitions[U: ClassTag](op: (RDD[Int]) => RDD[U]) { + def testRDDPartitions[U: ClassTag](op: (RDD[Int]) => RDD[U], + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _) { // Generate the final RDD using given RDD operation val baseRDD = generateFatRDD() val operatedRDD = op(baseRDD) @@ -316,13 +331,13 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) parentRDDs.foreach(_.checkpoint()) // checkpoint the parent RDD, not the generated one - val result = operatedRDD.collect() // force checkpointing + val result = collectFunc(operatedRDD) // force checkpointing operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) // Test whether the data in the checkpointed RDD is same as original - assert(operatedRDD.collect() === result) + assert(collectFunc(operatedRDD) === result) // Test whether serialized size of the partitions has reduced logInfo("Size of partitions of " + rddType + @@ -436,7 +451,7 @@ object CheckpointSuite { new CoGroupedRDD[K]( Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]), part - ).asInstanceOf[RDD[(K, Seq[Seq[V]])]] + ).asInstanceOf[RDD[(K, Array[Iterable[V]])]] } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 237e644b48e49..eae67c7747e82 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -176,7 +176,9 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3")) val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2) - val results = new CoGroupedRDD[Int](Seq(pairs1, pairs2), new HashPartitioner(2)).collectAsMap() + val results = new CoGroupedRDD[Int](Seq(pairs1, pairs2), new HashPartitioner(2)) + .map(p => (p._1, p._2.map(_.toArray))) + .collectAsMap() assert(results(1)(0).length === 3) assert(results(1)(0).contains(1)) diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala index 1fde4badda949..fb18c3ebfe46f 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -70,7 +70,7 @@ package object testPackage extends Assertions { def runCallSiteTest(sc: SparkContext) { val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) val rddCreationSite = rdd.getCreationSite - val curCallSite = sc.getCallSite().short // note: 2 lines after definition of "rdd" + val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd" val rddCreationLine = rddCreationSite match { case CALL_SITE_REGEX(func, file, line) => { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 565c53e9529ff..f497a5e0a14f0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -120,6 +120,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "beauty", + "--conf", "spark.shuffle.spill=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -139,6 +140,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { mainClass should be ("org.apache.spark.deploy.yarn.Client") classpath should have length (0) sysProps("spark.app.name") should be ("beauty") + sysProps("spark.shuffle.spill") should be ("false") sysProps("SPARK_SUBMIT") should be ("true") } @@ -156,6 +158,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "trill", + "--conf", "spark.shuffle.spill=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -176,6 +179,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") sysProps("spark.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") sysProps("SPARK_SUBMIT") should be ("true") + sysProps("spark.shuffle.spill") should be ("false") } test("handles standalone cluster mode") { @@ -186,6 +190,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { "--supervise", "--driver-memory", "4g", "--driver-cores", "5", + "--conf", "spark.shuffle.spill=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -195,9 +200,10 @@ class SparkSubmitSuite extends FunSuite with Matchers { childArgsStr should include regex ("launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2") mainClass should be ("org.apache.spark.deploy.Client") classpath should have size (0) - sysProps should have size (2) + sysProps should have size (3) sysProps.keys should contain ("spark.jars") sysProps.keys should contain ("SPARK_SUBMIT") + sysProps("spark.shuffle.spill") should be ("false") } test("handles standalone client mode") { @@ -208,6 +214,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", + "--conf", "spark.shuffle.spill=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -218,6 +225,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") + sysProps("spark.shuffle.spill") should be ("false") } test("handles mesos client mode") { @@ -228,6 +236,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", + "--conf", "spark.shuffle.spill=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -238,6 +247,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") + sysProps("spark.shuffle.spill") should be ("false") } test("launch simple application with spark-submit") { diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala index 5dd8de319a654..a0483886f8db3 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala @@ -43,7 +43,7 @@ class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext { test("seed distribution") { val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2) val sampler = new MockSampler - val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L) + val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, false, 0L) assert(sample.distinct().count == 2, "Seeds must be different.") } @@ -52,7 +52,7 @@ class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext { // We want to make sure there are no concurrency issues. val rdd = sc.parallelize(0 until 111, 10) for (sampler <- Seq(new BernoulliSampler[Int](0.5), new PoissonSampler[Int](0.5))) { - val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler) + val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler, true) sampled.zip(sampled).count() } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 2924de112934c..6654ec2d7c656 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -523,6 +523,15 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(sortedTopK === nums.sorted(ord).take(5)) } + test("sample preserves partitioner") { + val partitioner = new HashPartitioner(2) + val rdd = sc.parallelize(Seq((0, 1), (2, 3))).partitionBy(partitioner) + for (withReplacement <- Seq(true, false)) { + val sampled = rdd.sample(withReplacement, 1.0) + assert(sampled.partitioner === rdd.partitioner) + } + } + test("takeSample") { val n = 1000000 val data = sc.parallelize(1 to n, 2) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 86b443b18f2a6..c52368b5514db 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -475,6 +475,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL and ANY assert(manager.myLocalityLevels.sameElements( Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY))) + FakeRackUtil.cleanUp() } test("test RACK_LOCAL tasks") { @@ -505,6 +506,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Offer host2 // Task 1 can be scheduled with RACK_LOCAL assert(manager.resourceOffer("execB", "host2", RACK_LOCAL).get.index === 1) + FakeRackUtil.cleanUp() } test("do not emit warning when serialized task is small") { diff --git a/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala new file mode 100644 index 0000000000000..6c956d93dc80d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/CompactBufferSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import org.scalatest.FunSuite + +class CompactBufferSuite extends FunSuite { + test("empty buffer") { + val b = new CompactBuffer[Int] + assert(b.size === 0) + assert(b.iterator.toList === Nil) + assert(b.size === 0) + assert(b.iterator.toList === Nil) + intercept[IndexOutOfBoundsException] { b(0) } + intercept[IndexOutOfBoundsException] { b(1) } + intercept[IndexOutOfBoundsException] { b(2) } + intercept[IndexOutOfBoundsException] { b(-1) } + } + + test("basic inserts") { + val b = new CompactBuffer[Int] + assert(b.size === 0) + assert(b.iterator.toList === Nil) + for (i <- 0 until 1000) { + b += i + assert(b.size === i + 1) + assert(b(i) === i) + } + assert(b.iterator.toList === (0 until 1000).toList) + assert(b.iterator.toList === (0 until 1000).toList) + assert(b.size === 1000) + } + + test("adding sequences") { + val b = new CompactBuffer[Int] + assert(b.size === 0) + assert(b.iterator.toList === Nil) + + // Add some simple lists and iterators + b ++= List(0) + assert(b.size === 1) + assert(b.iterator.toList === List(0)) + b ++= Iterator(1) + assert(b.size === 2) + assert(b.iterator.toList === List(0, 1)) + b ++= List(2) + assert(b.size === 3) + assert(b.iterator.toList === List(0, 1, 2)) + b ++= Iterator(3, 4, 5, 6, 7, 8, 9) + assert(b.size === 10) + assert(b.iterator.toList === (0 until 10).toList) + + // Add CompactBuffers + val b2 = new CompactBuffer[Int] + b2 ++= 0 until 10 + b ++= b2 + assert(b.iterator.toList === (1 to 2).flatMap(i => 0 until 10).toList) + b ++= b2 + assert(b.iterator.toList === (1 to 3).flatMap(i => 0 until 10).toList) + b ++= b2 + assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList) + + // Add some small CompactBuffers as well + val b3 = new CompactBuffer[Int] + b ++= b3 + assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList) + b3 += 0 + b ++= b3 + assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList ++ List(0)) + b3 += 1 + b ++= b3 + assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList ++ List(0, 0, 1)) + b3 += 2 + b ++= b3 + assert(b.iterator.toList === (1 to 4).flatMap(i => 0 until 10).toList ++ List(0, 0, 1, 0, 1, 2)) + } + + test("adding the same buffer to itself") { + val b = new CompactBuffer[Int] + assert(b.size === 0) + assert(b.iterator.toList === Nil) + b += 1 + assert(b.toList === List(1)) + for (j <- 1 until 8) { + b ++= b + assert(b.size === (1 << j)) + assert(b.iterator.toList === (1 to (1 << j)).map(i => 1).toList) + } + } +} diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 38830103d1e8d..33de24d1ae6d7 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -53,7 +53,7 @@ if [[ ! "$@" =~ --package-only ]]; then -Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \ -Dmaven.javadoc.skip=true \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl\ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\ -Dtag=$GIT_TAG -DautoVersionSubmodules=true \ --batch-mode release:prepare @@ -61,7 +61,7 @@ if [[ ! "$@" =~ --package-only ]]; then -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Dmaven.javadoc.skip=true \ - -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl\ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\ release:perform cd .. @@ -111,10 +111,10 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } -make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" -make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" +make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" +make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" make_binary_release "hadoop2" \ - "-Phive -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" + "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" # Copy data echo "Copying release tarballs" diff --git a/dev/run-tests b/dev/run-tests index 51e4def0f835a..98ec969dc1b37 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -65,7 +65,7 @@ echo "=========================================================================" # (either resolution or compilation) prompts the user for input either q, r, # etc to quit or retry. This echo is there to make it not block. if [ -n "$_RUN_SQL_TESTS" ]; then - echo -e "q\n" | SBT_MAVEN_PROFILES="$SBT_MAVEN_PROFILES -Phive" sbt/sbt clean package \ + echo -e "q\n" | SBT_MAVEN_PROFILES="$SBT_MAVEN_PROFILES -Phive -Phive-thriftserver" sbt/sbt clean package \ assembly/assembly test | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" else echo -e "q\n" | sbt/sbt clean package assembly/assembly test | \ diff --git a/dev/scalastyle b/dev/scalastyle index a02d06912f238..d9f2b91a3a091 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,7 +17,7 @@ # limitations under the License. # -echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt +echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt # Check style with YARN alpha built too echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \ >> scalastyle.txt diff --git a/docs/configuration.md b/docs/configuration.md index a70007c165442..dac8bb1d52468 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -42,13 +42,15 @@ val sc = new SparkContext(new SparkConf()) Then, you can supply configuration values at runtime: {% highlight bash %} -./bin/spark-submit --name "My fancy app" --master local[4] myApp.jar +./bin/spark-submit --name "My app" --master local[4] --conf spark.shuffle.spill=false + --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" myApp.jar {% endhighlight %} The Spark shell and [`spark-submit`](cluster-overview.html#launching-applications-with-spark-submit) tool support two ways to load configurations dynamically. The first are command line options, -such as `--master`, as shown above. Running `./bin/spark-submit --help` will show the entire list -of options. +such as `--master`, as shown above. `spark-submit` can accept any Spark property using the `--conf` +flag, but uses special flags for properties that play a part in launching the Spark application. +Running `./bin/spark-submit --help` will show the entire list of these options. `bin/spark-submit` will also read configuration options from `conf/spark-defaults.conf`, in which each line consists of a key and a value separated by whitespace. For example: @@ -195,6 +197,15 @@ Apart from these, the following properties are also available, and may be useful Spark's dependencies and user dependencies. It is currently an experimental feature. + + spark.python.worker.memory + 512m + + Amount of memory to use per python worker process during aggregation, in the same + format as JVM memory strings (e.g. 512m, 2g). If the memory + used during aggregation goes above this amount, it will spill the data into disks. + + #### Shuffle Behavior @@ -388,6 +399,17 @@ Apart from these, the following properties are also available, and may be useful case. + + spark.kryo.registrationRequired + false + + Whether to require registration with Kryo. If set to 'true', Kryo will throw an exception + if an unregistered class is serialized. If set to false (the default), Kryo will write + unregistered class names along with each object. Writing class names can cause + significant performance overhead, so enabling this option can enforce strictly that a + user has not omitted classes from registration. + + spark.kryoserializer.buffer.mb 2 @@ -497,9 +519,9 @@ Apart from these, the following properties are also available, and may be useful spark.hadoop.validateOutputSpecs true - If set to true, validates the output specification (e.g. checking if the output directory already exists) - used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing - output directories. We recommend that users do not disable this except if trying to achieve compatibility with + If set to true, validates the output specification (e.g. checking if the output directory already exists) + used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing + output directories. We recommend that users do not disable this except if trying to achieve compatibility with previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. @@ -861,7 +883,7 @@ Apart from these, the following properties are also available, and may be useful #### Cluster Managers -Each cluster manager in Spark has additional configuration options. Configurations +Each cluster manager in Spark has additional configuration options. Configurations can be found on the pages for each mode: * [YARN](running-on-yarn.html#configuration) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 38728534a46e0..36d642f2923b2 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -136,7 +136,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) import sqlContext.createSchemaRDD // Define the schema using a case class. -// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, +// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, // you can use custom classes that implement the Product interface. case class Person(name: String, age: Int) @@ -548,7 +548,6 @@ results = hiveContext.hql("FROM src SELECT key, value").collect() - # Writing Language-Integrated Relational Queries **Language-Integrated queries are currently only supported in Scala.** @@ -573,4 +572,199 @@ prefixed with a tick (`'`). Implicit conversions turn these symbols into expres evaluated by the SQL execution engine. A full list of the functions supported can be found in the [ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). - \ No newline at end of file + + +## Running the Thrift JDBC server + +The Thrift JDBC server implemented here corresponds to the [`HiveServer2`] +(https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) in Hive 0.12. You can test +the JDBC server with the beeline script comes with either Spark or Hive 0.12. In order to use Hive +you must first run '`sbt/sbt -Phive-thriftserver assembly/assembly`' (or use `-Phive-thriftserver` +for maven). + +To start the JDBC server, run the following in the Spark directory: + + ./sbin/start-thriftserver.sh + +The default port the server listens on is 10000. You may run +`./sbin/start-thriftserver.sh --help` for a complete list of all available +options. Now you can use beeline to test the Thrift JDBC server: + + ./bin/beeline + +Connect to the JDBC server in beeline with: + + beeline> !connect jdbc:hive2://localhost:10000 + +Beeline will ask you for a username and password. In non-secure mode, simply enter the username on +your machine and a blank password. For secure mode, please follow the instructions given in the +[beeline documentation](https://cwiki.apache.org/confluence/display/Hive/HiveServer2+Clients) + +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. + +You may also use the beeline script comes with Hive. + +### Migration Guide for Shark Users + +#### Reducer number + +In Shark, default reducer number is 1 and is controlled by the property `mapred.reduce.tasks`. Spark +SQL deprecates this property by a new property `spark.sql.shuffle.partitions`, whose default value +is 200. Users may customize this property via `SET`: + +``` +SET spark.sql.shuffle.partitions=10; +SELECT page, count(*) c FROM logs_last_month_cached +GROUP BY page ORDER BY c DESC LIMIT 10; +``` + +You may also put this property in `hive-site.xml` to override the default value. + +For now, the `mapred.reduce.tasks` property is still recognized, and is converted to +`spark.sql.shuffle.partitions` automatically. + +#### Caching + +The `shark.cache` table property no longer exists, and tables whose name end with `_cached` are no +longer automcatically cached. Instead, we provide `CACHE TABLE` and `UNCACHE TABLE` statements to +let user control table caching explicitly: + +``` +CACHE TABLE logs_last_month; +UNCACHE TABLE logs_last_month; +``` + +**NOTE** `CACHE TABLE tbl` is lazy, it only marks table `tbl` as "need to by cached if necessary", +but doesn't actually cache it until a query that touches `tbl` is executed. To force the table to be +cached, you may simply count the table immediately after executing `CACHE TABLE`: + +``` +CACHE TABLE logs_last_month; +SELECT COUNT(1) FROM logs_last_month; +``` + +Several caching related features are not supported yet: + +* User defined partition level cache eviction policy +* RDD reloading +* In-memory cache write through policy + +### Compatibility with Apache Hive + +#### Deploying in Exising Hive Warehouses + +Spark SQL Thrift JDBC server is designed to be "out of the box" compatible with existing Hive +installations. You do not need to modify your existing Hive Metastore or change the data placement +or partitioning of your tables. + +#### Supported Hive Features + +Spark SQL supports the vast majority of Hive features, such as: + +* Hive query statements, including: + * `SELECT` + * `GROUP BY + * `ORDER BY` + * `CLUSTER BY` + * `SORT BY` +* All Hive operators, including: + * Relational operators (`=`, `⇔`, `==`, `<>`, `<`, `>`, `>=`, `<=`, etc) + * Arthimatic operators (`+`, `-`, `*`, `/`, `%`, etc) + * Logical operators (`AND`, `&&`, `OR`, `||`, etc) + * Complex type constructors + * Mathemtatical functions (`sign`, `ln`, `cos`, etc) + * String functions (`instr`, `length`, `printf`, etc) +* User defined functions (UDF) +* User defined aggregation functions (UDAF) +* User defined serialization formats (SerDe's) +* Joins + * `JOIN` + * `{LEFT|RIGHT|FULL} OUTER JOIN` + * `LEFT SEMI JOIN` + * `CROSS JOIN` +* Unions +* Sub queries + * `SELECT col FROM ( SELECT a + b AS col from t1) t2` +* Sampling +* Explain +* Partitioned tables +* All Hive DDL Functions, including: + * `CREATE TABLE` + * `CREATE TABLE AS SELECT` + * `ALTER TABLE` +* Most Hive Data types, including: + * `TINYINT` + * `SMALLINT` + * `INT` + * `BIGINT` + * `BOOLEAN` + * `FLOAT` + * `DOUBLE` + * `STRING` + * `BINARY` + * `TIMESTAMP` + * `ARRAY<>` + * `MAP<>` + * `STRUCT<>` + +#### Unsupported Hive Functionality + +Below is a list of Hive features that we don't support yet. Most of these features are rarely used +in Hive deployments. + +**Major Hive Features** + +* Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL + doesn't support buckets yet. + +**Esoteric Hive Features** + +* Tables with partitions using different input formats: In Spark SQL, all table partitions need to + have the same input format. +* Non-equi outer join: For the uncommon use case of using outer joins with non-equi join conditions + (e.g. condition "`key < 10`"), Spark SQL will output wrong result for the `NULL` tuple. +* `UNIONTYPE` +* Unique join +* Single query multi insert +* Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at + the moment. + +**Hive Input/Output Formats** + +* File format for CLI: For results showing back to the CLI, Spark SQL only supports TextOutputFormat. +* Hadoop archive + +**Hive Optimizations** + +A handful of Hive optimizations are not yet included in Spark. Some of these (such as indexes) are +not necessary due to Spark SQL's in-memory computational model. Others are slotted for future +releases of Spark SQL. + +* Block level bitmap indexes and virtual columns (used to build indexes) +* Automatically convert a join to map join: For joining a large table with multiple small tables, + Hive automatically converts the join into a map join. We are adding this auto conversion in the + next release. +* Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you + need to control the degree of parallelism post-shuffle using "SET + spark.sql.shuffle.partitions=[num_tasks];". We are going to add auto-setting of parallelism in the + next release. +* Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still + launches tasks to compute the result. +* Skew data flag: Spark SQL does not follow the skew data flags in Hive. +* `STREAMTABLE` hint in join: Spark SQL does not follow the `STREAMTABLE` hint. +* Merge multiple small files for query results: if the result output contains multiple small files, + Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS + metadata. Spark SQL does not support that. + +## Running the Spark SQL CLI + +The Spark SQL CLI is a convenient tool to run the Hive metastore service in local mode and execute +queries input from command line. Note: the Spark SQL CLI cannot talk to the Thrift JDBC server. + +To start the Spark SQL CLI, run the following in the Spark directory: + + ./bin/spark-sql + +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +You may run `./bin/spark-sql --help` for a complete list of all available +options. diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index e05883072bfa8..45b70b1a5457a 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -33,6 +33,7 @@ dependencies, and can support different cluster managers and deploy modes that S --class --master \ --deploy-mode \ + --conf = \ ... # other options \ [application-arguments] @@ -43,6 +44,7 @@ Some of the commonly used options are: * `--class`: The entry point for your application (e.g. `org.apache.spark.examples.SparkPi`) * `--master`: The [master URL](#master-urls) for the cluster (e.g. `spark://23.195.26.187:7077`) * `--deploy-mode`: Whether to deploy your driver on the worker nodes (`cluster`) or locally as an external client (`client`) (default: `client`)* +* `--conf`: Arbitrary Spark configuration property in key=value format. For values that contain spaces wrap "key=value" in quotes (as shown). * `application-jar`: Path to a bundled jar including your application and all dependencies. The URL must be globally visible inside of your cluster, for instance, an `hdfs://` path or a `file://` path that is present on all nodes. * `application-arguments`: Arguments passed to the main method of your main class, if any diff --git a/examples/pom.xml b/examples/pom.xml index bd1c387c2eb91..c4ed0f5a6a02b 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-examples_2.10 - examples + examples jar Spark Project Examples diff --git a/external/flume/pom.xml b/external/flume/pom.xml index e6b3cc36702c8..9f680b27c3308 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-flume_2.10 - streaming-flume + streaming-flume jar Spark Project External Flume diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 4762c50685a93..25a5c0a4d7d77 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-kafka_2.10 - streaming-kafka + streaming-kafka jar Spark Project External Kafka diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 32c530e600ce0..f31ed655f6779 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-mqtt_2.10 - streaming-mqtt + streaming-mqtt jar Spark Project External MQTT diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 637adb0f00da0..56bb24c2a072e 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-twitter_2.10 - streaming-twitter + streaming-twitter jar Spark Project External Twitter diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 5ea2e5549d7df..4eacc47da5699 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -63,7 +63,8 @@ class TwitterReceiver( storageLevel: StorageLevel ) extends Receiver[Status](storageLevel) with Logging { - private var twitterStream: TwitterStream = _ + @volatile private var twitterStream: TwitterStream = _ + @volatile private var stopped = false def onStart() { try { @@ -78,7 +79,9 @@ class TwitterReceiver( def onScrubGeo(l: Long, l1: Long) {} def onStallWarning(stallWarning: StallWarning) {} def onException(e: Exception) { - restart("Error receiving tweets", e) + if (!stopped) { + restart("Error receiving tweets", e) + } } }) @@ -91,12 +94,14 @@ class TwitterReceiver( } setTwitterStream(newTwitterStream) logInfo("Twitter receiver started") + stopped = false } catch { case e: Exception => restart("Error starting Twitter stream", e) } } def onStop() { + stopped = true setTwitterStream(null) logInfo("Twitter receiver stopped") } diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index e4d758a04a4cd..54b0242c54e78 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming-zeromq_2.10 - streaming-zeromq + streaming-zeromq jar Spark Project External ZeroMQ diff --git a/graphx/pom.xml b/graphx/pom.xml index 7e3bcf29dcfbc..6dd52fc618b1e 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-graphx_2.10 - graphx + graphx jar Spark Project GraphX diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 3507f358bfb40..fa4b891754c40 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -344,7 +344,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * * {{{ * val rawGraph: Graph[_, _] = Graph.textFile("webgraph") - * val outDeg: RDD[(VertexId, Int)] = rawGraph.outDegrees() + * val outDeg: RDD[(VertexId, Int)] = rawGraph.outDegrees * val graph = rawGraph.outerJoinVertices(outDeg) { * (vid, data, optDeg) => optDeg.getOrElse(0) * } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala index f97f329c0e832..1948c978c30bf 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala @@ -35,9 +35,6 @@ class GraphKryoRegistrator extends KryoRegistrator { def registerClasses(kryo: Kryo) { kryo.register(classOf[Edge[Object]]) - kryo.register(classOf[MessageToPartition[Object]]) - kryo.register(classOf[VertexBroadcastMsg[Object]]) - kryo.register(classOf[RoutingTableMessage]) kryo.register(classOf[(VertexId, Object)]) kryo.register(classOf[EdgePartition[Object, Object]]) kryo.register(classOf[BitSet]) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index edd5b79da1522..02afaa987d40d 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -198,10 +198,10 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * * {{{ * val rawGraph: Graph[Int, Int] = GraphLoader.edgeListFile(sc, "webgraph") - * .mapVertices(v => 0) - * val outDeg: RDD[(Int, Int)] = rawGraph.outDegrees - * val graph = rawGraph.leftJoinVertices[Int,Int](outDeg, - * (v, deg) => deg ) + * .mapVertices((_, _) => 0) + * val outDeg = rawGraph.outDegrees + * val graph = rawGraph.joinVertices[Int](outDeg) + * ((_, _, outDeg) => outDeg) * }}} * */ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index ccdaa82eb9162..33f35cfb69a26 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -26,7 +26,6 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl._ -import org.apache.spark.graphx.impl.MsgRDDFunctions._ import org.apache.spark.graphx.util.BytecodeUtils @@ -83,15 +82,13 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( val vdTag = classTag[VD] val newEdges = edges.withPartitionsRDD(edges.map { e => val part: PartitionID = partitionStrategy.getPartition(e.srcId, e.dstId, numPartitions) - - // Should we be using 3-tuple or an optimized class - new MessageToPartition(part, (e.srcId, e.dstId, e.attr)) + (part, (e.srcId, e.dstId, e.attr)) } .partitionBy(new HashPartitioner(numPartitions)) .mapPartitionsWithIndex( { (pid, iter) => val builder = new EdgePartitionBuilder[ED, VD]()(edTag, vdTag) iter.foreach { message => - val data = message.data + val data = message._2 builder.add(data._1, data._2, data._3) } val edgePartition = builder.toEdgePartition diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala index d85afa45b1264..5318b8da6412a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala @@ -25,82 +25,6 @@ import org.apache.spark.graphx.{PartitionID, VertexId} import org.apache.spark.rdd.{ShuffledRDD, RDD} -private[graphx] -class VertexBroadcastMsg[@specialized(Int, Long, Double, Boolean) T]( - @transient var partition: PartitionID, - var vid: VertexId, - var data: T) - extends Product2[PartitionID, (VertexId, T)] with Serializable { - - override def _1 = partition - - override def _2 = (vid, data) - - override def canEqual(that: Any): Boolean = that.isInstanceOf[VertexBroadcastMsg[_]] -} - - -/** - * A message used to send a specific value to a partition. - * @param partition index of the target partition. - * @param data value to send - */ -private[graphx] -class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/* , AnyRef */) T]( - @transient var partition: PartitionID, - var data: T) - extends Product2[PartitionID, T] with Serializable { - - override def _1 = partition - - override def _2 = data - - override def canEqual(that: Any): Boolean = that.isInstanceOf[MessageToPartition[_]] -} - - -private[graphx] -class VertexBroadcastMsgRDDFunctions[T: ClassTag](self: RDD[VertexBroadcastMsg[T]]) { - def partitionBy(partitioner: Partitioner): RDD[VertexBroadcastMsg[T]] = { - val rdd = new ShuffledRDD[PartitionID, (VertexId, T), (VertexId, T), VertexBroadcastMsg[T]]( - self, partitioner) - - // Set a custom serializer if the data is of int or double type. - if (classTag[T] == ClassTag.Int) { - rdd.setSerializer(new IntVertexBroadcastMsgSerializer) - } else if (classTag[T] == ClassTag.Long) { - rdd.setSerializer(new LongVertexBroadcastMsgSerializer) - } else if (classTag[T] == ClassTag.Double) { - rdd.setSerializer(new DoubleVertexBroadcastMsgSerializer) - } - rdd - } -} - - -private[graphx] -class MsgRDDFunctions[T: ClassTag](self: RDD[MessageToPartition[T]]) { - - /** - * Return a copy of the RDD partitioned using the specified partitioner. - */ - def partitionBy(partitioner: Partitioner): RDD[MessageToPartition[T]] = { - new ShuffledRDD[PartitionID, T, T, MessageToPartition[T]](self, partitioner) - } - -} - -private[graphx] -object MsgRDDFunctions { - implicit def rdd2PartitionRDDFunctions[T: ClassTag](rdd: RDD[MessageToPartition[T]]) = { - new MsgRDDFunctions(rdd) - } - - implicit def rdd2vertexMessageRDDFunctions[T: ClassTag](rdd: RDD[VertexBroadcastMsg[T]]) = { - new VertexBroadcastMsgRDDFunctions(rdd) - } -} - private[graphx] class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) { def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index 502b112d31c2e..a565d3b28bf52 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -27,26 +27,13 @@ import org.apache.spark.util.collection.{BitSet, PrimitiveVector} import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap -/** - * A message from the edge partition `pid` to the vertex partition containing `vid` specifying that - * the edge partition references `vid` in the specified `position` (src, dst, or both). -*/ -private[graphx] -class RoutingTableMessage( - var vid: VertexId, - var pid: PartitionID, - var position: Byte) - extends Product2[VertexId, (PartitionID, Byte)] with Serializable { - override def _1 = vid - override def _2 = (pid, position) - override def canEqual(that: Any): Boolean = that.isInstanceOf[RoutingTableMessage] -} +import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage private[graphx] class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) { /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */ def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = { - new ShuffledRDD[VertexId, (PartitionID, Byte), (PartitionID, Byte), RoutingTableMessage]( + new ShuffledRDD[VertexId, Int, Int, RoutingTableMessage]( self, partitioner).setSerializer(new RoutingTableMessageSerializer) } } @@ -62,6 +49,23 @@ object RoutingTableMessageRDDFunctions { private[graphx] object RoutingTablePartition { + /** + * A message from an edge partition to a vertex specifying the position in which the edge + * partition references the vertex (src, dst, or both). The edge partition is encoded in the lower + * 30 bytes of the Int, and the position is encoded in the upper 2 bytes of the Int. + */ + type RoutingTableMessage = (VertexId, Int) + + private def toMessage(vid: VertexId, pid: PartitionID, position: Byte): RoutingTableMessage = { + val positionUpper2 = position << 30 + val pidLower30 = pid & 0x3FFFFFFF + (vid, positionUpper2 | pidLower30) + } + + private def vidFromMessage(msg: RoutingTableMessage): VertexId = msg._1 + private def pidFromMessage(msg: RoutingTableMessage): PartitionID = msg._2 & 0x3FFFFFFF + private def positionFromMessage(msg: RoutingTableMessage): Byte = (msg._2 >> 30).toByte + val empty: RoutingTablePartition = new RoutingTablePartition(Array.empty) /** Generate a `RoutingTableMessage` for each vertex referenced in `edgePartition`. */ @@ -77,7 +81,9 @@ object RoutingTablePartition { map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte) } map.iterator.map { vidAndPosition => - new RoutingTableMessage(vidAndPosition._1, pid, vidAndPosition._2) + val vid = vidAndPosition._1 + val position = vidAndPosition._2 + toMessage(vid, pid, position) } } @@ -88,9 +94,12 @@ object RoutingTablePartition { val srcFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean]) val dstFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean]) for (msg <- iter) { - pid2vid(msg.pid) += msg.vid - srcFlags(msg.pid) += (msg.position & 0x1) != 0 - dstFlags(msg.pid) += (msg.position & 0x2) != 0 + val vid = vidFromMessage(msg) + val pid = pidFromMessage(msg) + val position = positionFromMessage(msg) + pid2vid(pid) += vid + srcFlags(pid) += (position & 0x1) != 0 + dstFlags(pid) += (position & 0x2) != 0 } new RoutingTablePartition(pid2vid.zipWithIndex.map { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala index 033237f597216..3909efcdfc993 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala @@ -24,9 +24,11 @@ import java.nio.ByteBuffer import scala.reflect.ClassTag -import org.apache.spark.graphx._ import org.apache.spark.serializer._ +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage + private[graphx] class RoutingTableMessageSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { @@ -35,10 +37,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable { new ShuffleSerializationStream(s) { def writeObject[T: ClassTag](t: T): SerializationStream = { val msg = t.asInstanceOf[RoutingTableMessage] - writeVarLong(msg.vid, optimizePositive = false) - writeUnsignedVarInt(msg.pid) - // TODO: Write only the bottom two bits of msg.position - s.write(msg.position) + writeVarLong(msg._1, optimizePositive = false) + writeInt(msg._2) this } } @@ -47,10 +47,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable { new ShuffleDeserializationStream(s) { override def readObject[T: ClassTag](): T = { val a = readVarLong(optimizePositive = false) - val b = readUnsignedVarInt() - val c = s.read() - if (c == -1) throw new EOFException - new RoutingTableMessage(a, b, c.toByte).asInstanceOf[T] + val b = readInt() + (a, b).asInstanceOf[T] } } } @@ -76,78 +74,6 @@ class VertexIdMsgSerializer extends Serializer with Serializable { } } -/** A special shuffle serializer for VertexBroadcastMessage[Int]. */ -private[graphx] -class IntVertexBroadcastMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[VertexBroadcastMsg[Int]] - writeVarLong(msg.vid, optimizePositive = false) - writeInt(msg.data) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readInt() - new VertexBroadcastMsg[Int](0, a, b).asInstanceOf[T] - } - } - } -} - -/** A special shuffle serializer for VertexBroadcastMessage[Long]. */ -private[graphx] -class LongVertexBroadcastMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[VertexBroadcastMsg[Long]] - writeVarLong(msg.vid, optimizePositive = false) - writeLong(msg.data) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readLong() - new VertexBroadcastMsg[Long](0, a, b).asInstanceOf[T] - } - } - } -} - -/** A special shuffle serializer for VertexBroadcastMessage[Double]. */ -private[graphx] -class DoubleVertexBroadcastMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[VertexBroadcastMsg[Double]] - writeVarLong(msg.vid, optimizePositive = false) - writeDouble(msg.data) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readDouble() - new VertexBroadcastMsg[Double](0, a, b).asInstanceOf[T] - } - } - } -} - /** A special shuffle serializer for AggregationMessage[Int]. */ private[graphx] class IntAggMsgSerializer extends Serializer with Serializable { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/package.scala b/graphx/src/main/scala/org/apache/spark/graphx/package.scala index ff17edeaf8f16..6aab28ff05355 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/package.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/package.scala @@ -30,7 +30,7 @@ package object graphx { */ type VertexId = Long - /** Integer identifer of a graph partition. */ + /** Integer identifer of a graph partition. Must be less than 2^30. */ // TODO: Consider using Char. type PartitionID = Int diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala index 91caa6b605a1e..864cb1fdf0022 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala @@ -26,75 +26,11 @@ import org.scalatest.FunSuite import org.apache.spark._ import org.apache.spark.graphx.impl._ -import org.apache.spark.graphx.impl.MsgRDDFunctions._ import org.apache.spark.serializer.SerializationStream class SerializerSuite extends FunSuite with LocalSparkContext { - test("IntVertexBroadcastMsgSerializer") { - val outMsg = new VertexBroadcastMsg[Int](3, 4, 5) - val bout = new ByteArrayOutputStream - val outStrm = new IntVertexBroadcastMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new IntVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: VertexBroadcastMsg[Int] = inStrm.readObject() - val inMsg2: VertexBroadcastMsg[Int] = inStrm.readObject() - assert(outMsg.vid === inMsg1.vid) - assert(outMsg.vid === inMsg2.vid) - assert(outMsg.data === inMsg1.data) - assert(outMsg.data === inMsg2.data) - - intercept[EOFException] { - inStrm.readObject() - } - } - - test("LongVertexBroadcastMsgSerializer") { - val outMsg = new VertexBroadcastMsg[Long](3, 4, 5) - val bout = new ByteArrayOutputStream - val outStrm = new LongVertexBroadcastMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new LongVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: VertexBroadcastMsg[Long] = inStrm.readObject() - val inMsg2: VertexBroadcastMsg[Long] = inStrm.readObject() - assert(outMsg.vid === inMsg1.vid) - assert(outMsg.vid === inMsg2.vid) - assert(outMsg.data === inMsg1.data) - assert(outMsg.data === inMsg2.data) - - intercept[EOFException] { - inStrm.readObject() - } - } - - test("DoubleVertexBroadcastMsgSerializer") { - val outMsg = new VertexBroadcastMsg[Double](3, 4, 5.0) - val bout = new ByteArrayOutputStream - val outStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: VertexBroadcastMsg[Double] = inStrm.readObject() - val inMsg2: VertexBroadcastMsg[Double] = inStrm.readObject() - assert(outMsg.vid === inMsg1.vid) - assert(outMsg.vid === inMsg2.vid) - assert(outMsg.data === inMsg1.data) - assert(outMsg.data === inMsg2.data) - - intercept[EOFException] { - inStrm.readObject() - } - } - test("IntAggMsgSerializer") { val outMsg = (4: VertexId, 5) val bout = new ByteArrayOutputStream @@ -152,15 +88,6 @@ class SerializerSuite extends FunSuite with LocalSparkContext { } } - test("TestShuffleVertexBroadcastMsg") { - withSpark { sc => - val bmsgs = sc.parallelize(0 until 100, 10).map { pid => - new VertexBroadcastMsg[Int](pid, pid, pid) - } - bmsgs.partitionBy(new HashPartitioner(3)).collect() - } - } - test("variable long encoding") { def testVarLongEncoding(v: Long, optimizePositive: Boolean) { val bout = new ByteArrayOutputStream diff --git a/mllib/pom.xml b/mllib/pom.xml index 92b07e2357db1..f27cf520dc9fa 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-mllib_2.10 - mllib + mllib jar Spark Project ML Library diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index 079743742d86d..1af40de2c7fcf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -103,11 +103,11 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) extends mergeValue = (c: BinaryLabelCounter, label: Double) => c += label, mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2 ).sortByKey(ascending = false) - val agg = counts.values.mapPartitions({ iter => + val agg = counts.values.mapPartitions { iter => val agg = new BinaryLabelCounter() iter.foreach(agg += _) Iterator(agg) - }, preservesPartitioning = true).collect() + }.collect() val partitionwiseCumulativeCounts = agg.scanLeft(new BinaryLabelCounter())( (agg: BinaryLabelCounter, c: BinaryLabelCounter) => agg.clone() += c) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index f4c403bc7861c..8c2b044ea73f2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -377,9 +377,9 @@ class RowMatrix( s"Only support dense matrix at this time but found ${B.getClass.getName}.") val Bb = rows.context.broadcast(B.toBreeze.asInstanceOf[BDM[Double]].toDenseVector.toArray) - val AB = rows.mapPartitions({ iter => + val AB = rows.mapPartitions { iter => val Bi = Bb.value - iter.map(row => { + iter.map { row => val v = BDV.zeros[Double](k) var i = 0 while (i < k) { @@ -387,8 +387,8 @@ class RowMatrix( i += 1 } Vectors.fromBreeze(v) - }) - }, preservesPartitioning = true) + } + } new RowMatrix(AB, nRows, B.numCols) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 15e8855db6ca7..5356790cb5339 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -430,7 +430,7 @@ class ALS private ( val inLinkBlock = makeInLinkBlock(numProductBlocks, ratings, productPartitioner) val outLinkBlock = makeOutLinkBlock(numProductBlocks, ratings, productPartitioner) Iterator.single((blockId, (inLinkBlock, outLinkBlock))) - }, true) + }, preservesPartitioning = true) val inLinks = links.mapValues(_._1) val outLinks = links.mapValues(_._2) inLinks.persist(StorageLevel.MEMORY_AND_DISK) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala index 88de2c82479b7..1f7de630e778c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala @@ -122,6 +122,10 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging { private def makeRankMatrix(ranks: Array[RDD[(Long, Double)]], input: RDD[Vector]): RDD[Vector] = { val partitioner = new HashPartitioner(input.partitions.size) val cogrouped = new CoGroupedRDD[Long](ranks, partitioner) - cogrouped.map { case (_, values: Seq[Seq[Double]]) => new DenseVector(values.flatten.toArray) } + cogrouped.map { + case (_, values: Array[Iterable[_]]) => + val doubles = values.asInstanceOf[Array[Iterable[Double]]] + new DenseVector(doubles.flatten.toArray) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index aaf92a1a8869a..30de24ad89f98 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -264,8 +264,8 @@ object MLUtils { (1 to numFolds).map { fold => val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, complement = false) - val validation = new PartitionwiseSampledRDD(rdd, sampler, seed) - val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), seed) + val validation = new PartitionwiseSampledRDD(rdd, sampler, true, seed) + val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), true, seed) (training, validation) }.toArray } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index 9d16182f9d8c4..94db1dc183230 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -20,8 +20,26 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils.DoubleWithAlmostEquals class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { + + // TODO: move utility functions to TestingUtils. + + def elementsAlmostEqual(actual: Seq[Double], expected: Seq[Double]): Boolean = { + actual.zip(expected).forall { case (x1, x2) => + x1.almostEquals(x2) + } + } + + def elementsAlmostEqual( + actual: Seq[(Double, Double)], + expected: Seq[(Double, Double)])(implicit dummy: DummyImplicit): Boolean = { + actual.zip(expected).forall { case ((x1, y1), (x2, y2)) => + x1.almostEquals(x2) && y1.almostEquals(y2) + } + } + test("binary evaluation metrics") { val scoreAndLabels = sc.parallelize( Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2) @@ -41,14 +59,14 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { val prCurve = Seq((0.0, 1.0)) ++ pr val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) } val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} - assert(metrics.thresholds().collect().toSeq === threshold) - assert(metrics.roc().collect().toSeq === rocCurve) - assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve)) - assert(metrics.pr().collect().toSeq === prCurve) - assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve)) - assert(metrics.fMeasureByThreshold().collect().toSeq === threshold.zip(f1)) - assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === threshold.zip(f2)) - assert(metrics.precisionByThreshold().collect().toSeq === threshold.zip(precision)) - assert(metrics.recallByThreshold().collect().toSeq === threshold.zip(recall)) + assert(elementsAlmostEqual(metrics.thresholds().collect(), threshold)) + assert(elementsAlmostEqual(metrics.roc().collect(), rocCurve)) + assert(metrics.areaUnderROC().almostEquals(AreaUnderCurve.of(rocCurve))) + assert(elementsAlmostEqual(metrics.pr().collect(), prCurve)) + assert(metrics.areaUnderPR().almostEquals(AreaUnderCurve.of(prCurve))) + assert(elementsAlmostEqual(metrics.fMeasureByThreshold().collect(), threshold.zip(f1))) + assert(elementsAlmostEqual(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2))) + assert(elementsAlmostEqual(metrics.precisionByThreshold().collect(), threshold.zip(precision))) + assert(elementsAlmostEqual(metrics.recallByThreshold().collect(), threshold.zip(recall))) } } diff --git a/pom.xml b/pom.xml index de6fae8c4b932..96a0c60d24de6 100644 --- a/pom.xml +++ b/pom.xml @@ -95,6 +95,7 @@ sql/catalyst sql/core sql/hive + sql/hive-thriftserver repl assembly external/twitter @@ -253,9 +254,9 @@ 3.3.2 - commons-codec - commons-codec - 1.5 + commons-codec + commons-codec + 1.5 com.google.code.findbugs diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 5e5ddd227aab6..e9220db6b1f9a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -32,108 +32,83 @@ import com.typesafe.tools.mima.core._ */ object MimaExcludes { - def excludes(version: String) = version match { - case v if v.startsWith("1.1") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("graphx") - ) ++ - closures.map(method => ProblemFilters.exclude[MissingMethodProblem](method)) ++ - Seq( - // Adding new method to JavaRDLike trait - we should probably mark this as a developer API. - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"), - // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values - // for countApproxDistinct* functions, which does not work in Java. We later removed - // them, and use the following to tell Mima to not care about them. - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.MemoryStore.Entry"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$debugChildren$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$firstDebugString$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$shuffleDebugString$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$debugString$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$" - + "createZero$1") - ) ++ - Seq( - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.FlumeReceiver.this") - ) ++ - Seq( // Ignore some private methods in ALS. - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), - ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments. - "org.apache.spark.mllib.recommendation.ALS.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures") - ) ++ - MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++ - MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ - MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++ - MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ - MimaBuild.excludeSparkClass("storage.Values") ++ - MimaBuild.excludeSparkClass("storage.Entry") ++ - MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ - Seq( - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Gini.calculate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Variance.calculate") - ) - case v if v.startsWith("1.0") => - Seq( - MimaBuild.excludeSparkPackage("api.java"), - MimaBuild.excludeSparkPackage("mllib"), - MimaBuild.excludeSparkPackage("streaming") - ) ++ - MimaBuild.excludeSparkClass("rdd.ClassTags") ++ - MimaBuild.excludeSparkClass("util.XORShiftRandom") ++ - MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++ - MimaBuild.excludeSparkClass("graphx.VertexRDD") ++ - MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++ - MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++ - MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++ - MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++ - MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ - MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++ - MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ - MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++ - MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD") - case _ => Seq() - } - - private val closures = Seq( - "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$mergeMaps$1", - "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$countPartition$1", - "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$distributePartition$1", - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$mergeValue$1", - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$writeToFile$1", - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$reducePartition$1", - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$writeShard$1", - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$mergeCombiners$1", - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$process$1", - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$createCombiner$1", - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$mergeMaps$1" - ) + def excludes(version: String) = + version match { + case v if v.startsWith("1.1") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("graphx") + ) ++ + Seq( + // Adding new method to JavaRDLike trait - we should probably mark this as a developer API. + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"), + // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values + // for countApproxDistinct* functions, which does not work in Java. We later removed + // them, and use the following to tell Mima to not care about them. + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.MemoryStore.Entry") + ) ++ + Seq( + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.FlumeReceiver.this") + ) ++ + Seq( // Ignore some private methods in ALS. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), + ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments. + "org.apache.spark.mllib.recommendation.ALS.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures") + ) ++ + MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++ + MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ + MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++ + MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ + MimaBuild.excludeSparkClass("storage.Values") ++ + MimaBuild.excludeSparkClass("storage.Entry") ++ + MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ + Seq( + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Gini.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Variance.calculate") + ) + case v if v.startsWith("1.0") => + Seq( + MimaBuild.excludeSparkPackage("api.java"), + MimaBuild.excludeSparkPackage("mllib"), + MimaBuild.excludeSparkPackage("streaming") + ) ++ + MimaBuild.excludeSparkClass("rdd.ClassTags") ++ + MimaBuild.excludeSparkClass("util.XORShiftRandom") ++ + MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++ + MimaBuild.excludeSparkClass("graphx.VertexRDD") ++ + MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++ + MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++ + MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++ + MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++ + MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ + MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++ + MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ + MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++ + MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD") + case _ => Seq() + } } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 1abbdff06c9bb..162bb8900cb90 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConversions._ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ +import sbtunidoc.Plugin.genjavadocSettings import org.scalastyle.sbt.ScalastylePlugin.{Settings => ScalaStyleSettings} import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} import net.virtualvoid.sbt.graph.Plugin.graphSettings @@ -29,11 +30,10 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val allProjects@Seq(bagel, catalyst, core, graphx, hive, mllib, repl, spark, sql, streaming, - streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, - streamingZeromq) = - Seq("bagel", "catalyst", "core", "graphx", "hive", "mllib", "repl", "spark", "sql", - "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", + val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, spark, sql, + streaming, streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) = + Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", + "spark", "sql", "streaming", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl) = @@ -100,7 +100,7 @@ object SparkBuild extends PomBuild { Properties.envOrNone("SBT_MAVEN_PROPERTIES") match { case Some(v) => v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.split("=")).foreach(x => System.setProperty(x(0), x(1))) - case _ => + case _ => } override val userPropertiesMap = System.getProperties.toMap @@ -108,7 +108,7 @@ object SparkBuild extends PomBuild { lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") - lazy val sharedSettings = graphSettings ++ ScalaStyleSettings ++ Seq ( + lazy val sharedSettings = graphSettings ++ ScalaStyleSettings ++ genjavadocSettings ++ Seq ( javaHome := Properties.envOrNone("JAVA_HOME").map(file), incOptions := incOptions.value.withNameHashing(true), retrieveManaged := true, @@ -158,7 +158,7 @@ object SparkBuild extends PomBuild { /* Enable Mima for all projects except spark, hive, catalyst, sql and repl */ // TODO: Add Sql to mima checks - allProjects.filterNot(y => Seq(spark, sql, hive, catalyst, repl).exists(x => x == y)). + allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl).contains(x)). foreach (x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)) /* Enable Assembly for all assembly projects */ diff --git a/python/epydoc.conf b/python/epydoc.conf index b73860bad8263..51c0faf359939 100644 --- a/python/epydoc.conf +++ b/python/epydoc.conf @@ -35,4 +35,4 @@ private: no exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests pyspark.rddsampler pyspark.daemon pyspark.mllib._common - pyspark.mllib.tests + pyspark.mllib.tests pyspark.shuffle diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index b50590ab3b444..b4c82f519bd53 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -100,6 +100,12 @@ def set(self, key, value): self._jconf.set(key, unicode(value)) return self + def setIfMissing(self, key, value): + """Set a configuration property, if not already set.""" + if self.get(key) is None: + self.set(key, value) + return self + def setMaster(self, value): """Set master URL to connect to.""" self._jconf.setMaster(value) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index e21be0e10a3f7..024fb881877c9 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -101,7 +101,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, else: self.serializer = BatchedSerializer(self._unbatched_serializer, batchSize) - + self._conf.setIfMissing("spark.rdd.compress", "true") # Set any parameters passed directly to us on the conf if master: self._conf.setMaster(master) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 94ba22306afbd..113a082e16721 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -39,9 +39,11 @@ from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter -from pyspark.rddsampler import RDDSampler +from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable +from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ + get_used_memory from py4j.java_collections import ListConverter, MapConverter @@ -197,6 +199,22 @@ def _replaceRoot(self, value): self._sink(1) +def _parse_memory(s): + """ + Parse a memory string in the format supported by Java (e.g. 1g, 200m) and + return the value in MB + + >>> _parse_memory("256m") + 256 + >>> _parse_memory("2g") + 2048 + """ + units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024} + if s[-1] not in units: + raise ValueError("invalid format: " + s) + return int(float(s[:-1]) * units[s[-1].lower()]) + + class RDD(object): """ @@ -231,10 +249,10 @@ def context(self): def cache(self): """ - Persist this RDD with the default storage level (C{MEMORY_ONLY}). + Persist this RDD with the default storage level (C{MEMORY_ONLY_SER}). """ self.is_cached = True - self._jrdd.cache() + self.persist(StorageLevel.MEMORY_ONLY_SER) return self def persist(self, storageLevel): @@ -393,7 +411,7 @@ def sample(self, withReplacement, fraction, seed=None): >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98] """ - assert fraction >= 0.0, "Invalid fraction value: %s" % fraction + assert fraction >= 0.0, "Negative fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) # this is ported from scala/spark/RDD.scala @@ -1207,20 +1225,49 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash): if numPartitions is None: numPartitions = self._defaultReducePartitions() - # Transferring O(n) objects to Java is too expensive. Instead, we'll - # form the hash buckets in Python, transferring O(numPartitions) objects - # to Java. Each object is a (splitNumber, [objects]) pair. + # Transferring O(n) objects to Java is too expensive. + # Instead, we'll form the hash buckets in Python, + # transferring O(numPartitions) objects to Java. + # Each object is a (splitNumber, [objects]) pair. + # In order to avoid too huge objects, the objects are + # grouped into chunks. outputSerializer = self.ctx._unbatched_serializer + limit = (_parse_memory(self.ctx._conf.get( + "spark.python.worker.memory", "512m")) / 2) + def add_shuffle_key(split, iterator): buckets = defaultdict(list) + c, batch = 0, min(10 * numPartitions, 1000) for (k, v) in iterator: buckets[partitionFunc(k) % numPartitions].append((k, v)) + c += 1 + + # check used memory and avg size of chunk of objects + if (c % 1000 == 0 and get_used_memory() > limit + or c > batch): + n, size = len(buckets), 0 + for split in buckets.keys(): + yield pack_long(split) + d = outputSerializer.dumps(buckets[split]) + del buckets[split] + yield d + size += len(d) + + avg = (size / n) >> 20 + # let 1M < avg < 10M + if avg < 1: + batch *= 1.5 + elif avg > 10: + batch = max(batch / 1.5, 1) + c = 0 + for (split, items) in buckets.iteritems(): yield pack_long(split) yield outputSerializer.dumps(items) + keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True with _JavaStackTrace(self.context) as st: @@ -1230,8 +1277,8 @@ def add_shuffle_key(split, iterator): id(partitionFunc)) jrdd = pairRDD.partitionBy(partitioner).values() rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) - # This is required so that id(partitionFunc) remains unique, even if - # partitionFunc is a lambda: + # This is required so that id(partitionFunc) remains unique, + # even if partitionFunc is a lambda: rdd._partitionFunc = partitionFunc return rdd @@ -1265,26 +1312,28 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, if numPartitions is None: numPartitions = self._defaultReducePartitions() + serializer = self.ctx.serializer + spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() + == 'true') + memory = _parse_memory(self.ctx._conf.get( + "spark.python.worker.memory", "512m")) + agg = Aggregator(createCombiner, mergeValue, mergeCombiners) + def combineLocally(iterator): - combiners = {} - for x in iterator: - (k, v) = x - if k not in combiners: - combiners[k] = createCombiner(v) - else: - combiners[k] = mergeValue(combiners[k], v) - return combiners.iteritems() + merger = ExternalMerger(agg, memory * 0.9, serializer) \ + if spill else InMemoryMerger(agg) + merger.mergeValues(iterator) + return merger.iteritems() + locally_combined = self.mapPartitions(combineLocally) shuffled = locally_combined.partitionBy(numPartitions) def _mergeCombiners(iterator): - combiners = {} - for (k, v) in iterator: - if k not in combiners: - combiners[k] = v - else: - combiners[k] = mergeCombiners(combiners[k], v) - return combiners.iteritems() + merger = ExternalMerger(agg, memory, serializer) \ + if spill else InMemoryMerger(agg) + merger.mergeCombiners(iterator) + return merger.iteritems() + return shuffled.mapPartitions(_mergeCombiners) def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): @@ -1343,7 +1392,8 @@ def mergeValue(xs, x): return xs def mergeCombiners(a, b): - return a + b + a.extend(b) + return a return self.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions).mapValues(lambda x: ResultIterable(x)) @@ -1406,6 +1456,27 @@ def cogroup(self, other, numPartitions=None): """ return python_cogroup((self, other), numPartitions) + def sampleByKey(self, withReplacement, fractions, seed=None): + """ + Return a subset of this RDD sampled by key (via stratified sampling). + Create a sample of this RDD using variable sampling rates for + different keys as specified by fractions, a key to sampling rate map. + + >>> fractions = {"a": 0.2, "b": 0.1} + >>> rdd = sc.parallelize(fractions.keys()).cartesian(sc.parallelize(range(0, 1000))) + >>> sample = dict(rdd.sampleByKey(False, fractions, 2).groupByKey().collect()) + >>> 100 < len(sample["a"]) < 300 and 50 < len(sample["b"]) < 150 + True + >>> max(sample["a"]) <= 999 and min(sample["a"]) >= 0 + True + >>> max(sample["b"]) <= 999 and min(sample["b"]) >= 0 + True + """ + for fraction in fractions.values(): + assert fraction >= 0.0, "Negative fraction value: %s" % fraction + return self.mapPartitionsWithIndex( \ + RDDStratifiedSampler(withReplacement, fractions, seed).func, True) + def subtractByKey(self, other, numPartitions=None): """ Return each (key, value) pair in C{self} that has no pair with matching diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 7ff1c316c7623..2df000fdb08ca 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -19,8 +19,8 @@ import random -class RDDSampler(object): - def __init__(self, withReplacement, fraction, seed=None): +class RDDSamplerBase(object): + def __init__(self, withReplacement, seed=None): try: import numpy self._use_numpy = True @@ -32,7 +32,6 @@ def __init__(self, withReplacement, fraction, seed=None): self._seed = seed if seed is not None else random.randint(0, sys.maxint) self._withReplacement = withReplacement - self._fraction = fraction self._random = None self._split = None self._rand_initialized = False @@ -94,6 +93,12 @@ def shuffle(self, vals): else: self._random.shuffle(vals, self._random.random) + +class RDDSampler(RDDSamplerBase): + def __init__(self, withReplacement, fraction, seed=None): + RDDSamplerBase.__init__(self, withReplacement, seed) + self._fraction = fraction + def func(self, split, iterator): if self._withReplacement: for obj in iterator: @@ -107,3 +112,22 @@ def func(self, split, iterator): for obj in iterator: if self.getUniformSample(split) <= self._fraction: yield obj + +class RDDStratifiedSampler(RDDSamplerBase): + def __init__(self, withReplacement, fractions, seed=None): + RDDSamplerBase.__init__(self, withReplacement, seed) + self._fractions = fractions + + def func(self, split, iterator): + if self._withReplacement: + for key, val in iterator: + # For large datasets, the expected number of occurrences of each element in + # a sample with replacement is Poisson(frac). We use that to get a count for + # each element. + count = self.getPoissonSample(split, mean=self._fractions[key]) + for _ in range(0, count): + yield key, val + else: + for key, val in iterator: + if self.getUniformSample(split) <= self._fractions[key]: + yield key, val diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 9be78b39fbc21..03b31ae9624c2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -193,7 +193,7 @@ def load_stream(self, stream): return chain.from_iterable(self._load_stream_without_unbatching(stream)) def _load_stream_without_unbatching(self, stream): - return self.serializer.load_stream(stream) + return self.serializer.load_stream(stream) def __eq__(self, other): return (isinstance(other, BatchedSerializer) and @@ -302,6 +302,33 @@ class MarshalSerializer(FramedSerializer): loads = marshal.loads +class AutoSerializer(FramedSerializer): + """ + Choose marshal or cPickle as serialization protocol autumatically + """ + def __init__(self): + FramedSerializer.__init__(self) + self._type = None + + def dumps(self, obj): + if self._type is not None: + return 'P' + cPickle.dumps(obj, -1) + try: + return 'M' + marshal.dumps(obj) + except Exception: + self._type = 'P' + return 'P' + cPickle.dumps(obj, -1) + + def loads(self, obj): + _type = obj[0] + if _type == 'M': + return marshal.loads(obj[1:]) + elif _type == 'P': + return cPickle.loads(obj[1:]) + else: + raise ValueError("invalid sevialization type: %s" % _type) + + class UTF8Deserializer(Serializer): """ Deserializes streams written by String.getBytes. diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py new file mode 100644 index 0000000000000..e3923d1c36c57 --- /dev/null +++ b/python/pyspark/shuffle.py @@ -0,0 +1,439 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import sys +import platform +import shutil +import warnings +import gc + +from pyspark.serializers import BatchedSerializer, PickleSerializer + +try: + import psutil + + def get_used_memory(): + """ Return the used memory in MB """ + process = psutil.Process(os.getpid()) + if hasattr(process, "memory_info"): + info = process.memory_info() + else: + info = process.get_memory_info() + return info.rss >> 20 +except ImportError: + + def get_used_memory(): + """ Return the used memory in MB """ + if platform.system() == 'Linux': + for line in open('/proc/self/status'): + if line.startswith('VmRSS:'): + return int(line.split()[1]) >> 10 + else: + warnings.warn("Please install psutil to have better " + "support with spilling") + if platform.system() == "Darwin": + import resource + rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + return rss >> 20 + # TODO: support windows + return 0 + + +class Aggregator(object): + + """ + Aggregator has tree functions to merge values into combiner. + + createCombiner: (value) -> combiner + mergeValue: (combine, value) -> combiner + mergeCombiners: (combiner, combiner) -> combiner + """ + + def __init__(self, createCombiner, mergeValue, mergeCombiners): + self.createCombiner = createCombiner + self.mergeValue = mergeValue + self.mergeCombiners = mergeCombiners + + +class SimpleAggregator(Aggregator): + + """ + SimpleAggregator is useful for the cases that combiners have + same type with values + """ + + def __init__(self, combiner): + Aggregator.__init__(self, lambda x: x, combiner, combiner) + + +class Merger(object): + + """ + Merge shuffled data together by aggregator + """ + + def __init__(self, aggregator): + self.agg = aggregator + + def mergeValues(self, iterator): + """ Combine the items by creator and combiner """ + raise NotImplementedError + + def mergeCombiners(self, iterator): + """ Merge the combined items by mergeCombiner """ + raise NotImplementedError + + def iteritems(self): + """ Return the merged items ad iterator """ + raise NotImplementedError + + +class InMemoryMerger(Merger): + + """ + In memory merger based on in-memory dict. + """ + + def __init__(self, aggregator): + Merger.__init__(self, aggregator) + self.data = {} + + def mergeValues(self, iterator): + """ Combine the items by creator and combiner """ + # speed up attributes lookup + d, creator = self.data, self.agg.createCombiner + comb = self.agg.mergeValue + for k, v in iterator: + d[k] = comb(d[k], v) if k in d else creator(v) + + def mergeCombiners(self, iterator): + """ Merge the combined items by mergeCombiner """ + # speed up attributes lookup + d, comb = self.data, self.agg.mergeCombiners + for k, v in iterator: + d[k] = comb(d[k], v) if k in d else v + + def iteritems(self): + """ Return the merged items ad iterator """ + return self.data.iteritems() + + +class ExternalMerger(Merger): + + """ + External merger will dump the aggregated data into disks when + memory usage goes above the limit, then merge them together. + + This class works as follows: + + - It repeatedly combine the items and save them in one dict in + memory. + + - When the used memory goes above memory limit, it will split + the combined data into partitions by hash code, dump them + into disk, one file per partition. + + - Then it goes through the rest of the iterator, combine items + into different dict by hash. Until the used memory goes over + memory limit, it dump all the dicts into disks, one file per + dict. Repeat this again until combine all the items. + + - Before return any items, it will load each partition and + combine them seperately. Yield them before loading next + partition. + + - During loading a partition, if the memory goes over limit, + it will partition the loaded data and dump them into disks + and load them partition by partition again. + + `data` and `pdata` are used to hold the merged items in memory. + At first, all the data are merged into `data`. Once the used + memory goes over limit, the items in `data` are dumped indo + disks, `data` will be cleared, all rest of items will be merged + into `pdata` and then dumped into disks. Before returning, all + the items in `pdata` will be dumped into disks. + + Finally, if any items were spilled into disks, each partition + will be merged into `data` and be yielded, then cleared. + + >>> agg = SimpleAggregator(lambda x, y: x + y) + >>> merger = ExternalMerger(agg, 10) + >>> N = 10000 + >>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10) + >>> assert merger.spills > 0 + >>> sum(v for k,v in merger.iteritems()) + 499950000 + + >>> merger = ExternalMerger(agg, 10) + >>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10) + >>> assert merger.spills > 0 + >>> sum(v for k,v in merger.iteritems()) + 499950000 + """ + + # the max total partitions created recursively + MAX_TOTAL_PARTITIONS = 4096 + + def __init__(self, aggregator, memory_limit=512, serializer=None, + localdirs=None, scale=1, partitions=59, batch=1000): + Merger.__init__(self, aggregator) + self.memory_limit = memory_limit + # default serializer is only used for tests + self.serializer = serializer or \ + BatchedSerializer(PickleSerializer(), 1024) + self.localdirs = localdirs or self._get_dirs() + # number of partitions when spill data into disks + self.partitions = partitions + # check the memory after # of items merged + self.batch = batch + # scale is used to scale down the hash of key for recursive hash map + self.scale = scale + # unpartitioned merged data + self.data = {} + # partitioned merged data, list of dicts + self.pdata = [] + # number of chunks dumped into disks + self.spills = 0 + # randomize the hash of key, id(o) is the address of o (aligned by 8) + self._seed = id(self) + 7 + + def _get_dirs(self): + """ Get all the directories """ + path = os.environ.get("SPARK_LOCAL_DIR", "/tmp") + dirs = path.split(",") + return [os.path.join(d, "python", str(os.getpid()), str(id(self))) + for d in dirs] + + def _get_spill_dir(self, n): + """ Choose one directory for spill by number n """ + return os.path.join(self.localdirs[n % len(self.localdirs)], str(n)) + + def _next_limit(self): + """ + Return the next memory limit. If the memory is not released + after spilling, it will dump the data only when the used memory + starts to increase. + """ + return max(self.memory_limit, get_used_memory() * 1.05) + + def mergeValues(self, iterator): + """ Combine the items by creator and combiner """ + iterator = iter(iterator) + # speedup attribute lookup + creator, comb = self.agg.createCombiner, self.agg.mergeValue + d, c, batch = self.data, 0, self.batch + + for k, v in iterator: + d[k] = comb(d[k], v) if k in d else creator(v) + + c += 1 + if c % batch == 0 and get_used_memory() > self.memory_limit: + self._spill() + self._partitioned_mergeValues(iterator, self._next_limit()) + break + + def _partition(self, key): + """ Return the partition for key """ + return hash((key, self._seed)) % self.partitions + + def _partitioned_mergeValues(self, iterator, limit=0): + """ Partition the items by key, then combine them """ + # speedup attribute lookup + creator, comb = self.agg.createCombiner, self.agg.mergeValue + c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch + + for k, v in iterator: + d = pdata[hfun(k)] + d[k] = comb(d[k], v) if k in d else creator(v) + if not limit: + continue + + c += 1 + if c % batch == 0 and get_used_memory() > limit: + self._spill() + limit = self._next_limit() + + def mergeCombiners(self, iterator, check=True): + """ Merge (K,V) pair by mergeCombiner """ + iterator = iter(iterator) + # speedup attribute lookup + d, comb, batch = self.data, self.agg.mergeCombiners, self.batch + c = 0 + for k, v in iterator: + d[k] = comb(d[k], v) if k in d else v + if not check: + continue + + c += 1 + if c % batch == 0 and get_used_memory() > self.memory_limit: + self._spill() + self._partitioned_mergeCombiners(iterator, self._next_limit()) + break + + def _partitioned_mergeCombiners(self, iterator, limit=0): + """ Partition the items by key, then merge them """ + comb, pdata = self.agg.mergeCombiners, self.pdata + c, hfun = 0, self._partition + for k, v in iterator: + d = pdata[hfun(k)] + d[k] = comb(d[k], v) if k in d else v + if not limit: + continue + + c += 1 + if c % self.batch == 0 and get_used_memory() > limit: + self._spill() + limit = self._next_limit() + + def _spill(self): + """ + dump already partitioned data into disks. + + It will dump the data in batch for better performance. + """ + path = self._get_spill_dir(self.spills) + if not os.path.exists(path): + os.makedirs(path) + + if not self.pdata: + # The data has not been partitioned, it will iterator the + # dataset once, write them into different files, has no + # additional memory. It only called when the memory goes + # above limit at the first time. + + # open all the files for writing + streams = [open(os.path.join(path, str(i)), 'w') + for i in range(self.partitions)] + + for k, v in self.data.iteritems(): + h = self._partition(k) + # put one item in batch, make it compatitable with load_stream + # it will increase the memory if dump them in batch + self.serializer.dump_stream([(k, v)], streams[h]) + + for s in streams: + s.close() + + self.data.clear() + self.pdata = [{} for i in range(self.partitions)] + + else: + for i in range(self.partitions): + p = os.path.join(path, str(i)) + with open(p, "w") as f: + # dump items in batch + self.serializer.dump_stream(self.pdata[i].iteritems(), f) + self.pdata[i].clear() + + self.spills += 1 + gc.collect() # release the memory as much as possible + + def iteritems(self): + """ Return all merged items as iterator """ + if not self.pdata and not self.spills: + return self.data.iteritems() + return self._external_items() + + def _external_items(self): + """ Return all partitioned items as iterator """ + assert not self.data + if any(self.pdata): + self._spill() + hard_limit = self._next_limit() + + try: + for i in range(self.partitions): + self.data = {} + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(i)) + # do not check memory during merging + self.mergeCombiners(self.serializer.load_stream(open(p)), + False) + + # limit the total partitions + if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS + and j < self.spills - 1 + and get_used_memory() > hard_limit): + self.data.clear() # will read from disk again + gc.collect() # release the memory as much as possible + for v in self._recursive_merged_items(i): + yield v + return + + for v in self.data.iteritems(): + yield v + self.data.clear() + gc.collect() + + # remove the merged partition + for j in range(self.spills): + path = self._get_spill_dir(j) + os.remove(os.path.join(path, str(i))) + + finally: + self._cleanup() + + def _cleanup(self): + """ Clean up all the files in disks """ + for d in self.localdirs: + shutil.rmtree(d, True) + + def _recursive_merged_items(self, start): + """ + merge the partitioned items and return the as iterator + + If one partition can not be fit in memory, then them will be + partitioned and merged recursively. + """ + # make sure all the data are dumps into disks. + assert not self.data + if any(self.pdata): + self._spill() + assert self.spills > 0 + + for i in range(start, self.partitions): + subdirs = [os.path.join(d, "parts", str(i)) + for d in self.localdirs] + m = ExternalMerger(self.agg, self.memory_limit, self.serializer, + subdirs, self.scale * self.partitions) + m.pdata = [{} for _ in range(self.partitions)] + limit = self._next_limit() + + for j in range(self.spills): + path = self._get_spill_dir(j) + p = os.path.join(path, str(i)) + m._partitioned_mergeCombiners( + self.serializer.load_stream(open(p))) + + if get_used_memory() > limit: + m._spill() + limit = self._next_limit() + + for v in m._external_items(): + yield v + + # remove the merged partition + for j in range(self.spills): + path = self._get_spill_dir(j) + os.remove(os.path.join(path, str(i))) + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 9c5ecd0bb02ab..a92abbf371f18 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -34,6 +34,7 @@ from pyspark.context import SparkContext from pyspark.files import SparkFiles from pyspark.serializers import read_int +from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger _have_scipy = False try: @@ -47,6 +48,62 @@ SPARK_HOME = os.environ["SPARK_HOME"] +class TestMerger(unittest.TestCase): + + def setUp(self): + self.N = 1 << 16 + self.l = [i for i in xrange(self.N)] + self.data = zip(self.l, self.l) + self.agg = Aggregator(lambda x: [x], + lambda x, y: x.append(y) or x, + lambda x, y: x.extend(y) or x) + + def test_in_memory(self): + m = InMemoryMerger(self.agg) + m.mergeValues(self.data) + self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + sum(xrange(self.N))) + + m = InMemoryMerger(self.agg) + m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) + self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + sum(xrange(self.N))) + + def test_small_dataset(self): + m = ExternalMerger(self.agg, 1000) + m.mergeValues(self.data) + self.assertEqual(m.spills, 0) + self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + sum(xrange(self.N))) + + m = ExternalMerger(self.agg, 1000) + m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) + self.assertEqual(m.spills, 0) + self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + sum(xrange(self.N))) + + def test_medium_dataset(self): + m = ExternalMerger(self.agg, 10) + m.mergeValues(self.data) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + sum(xrange(self.N))) + + m = ExternalMerger(self.agg, 10) + m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data * 3)) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + sum(xrange(self.N)) * 3) + + def test_huge_dataset(self): + m = ExternalMerger(self.agg, 10) + m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) + self.assertTrue(m.spills >= 1) + self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)), + self.N * 10) + m._cleanup() + + class PySparkTestCase(unittest.TestCase): def setUp(self): diff --git a/python/run-tests b/python/run-tests index 9282aa47e8375..29f755fc0dcd3 100755 --- a/python/run-tests +++ b/python/run-tests @@ -61,6 +61,7 @@ run_test "pyspark/broadcast.py" run_test "pyspark/accumulators.py" run_test "pyspark/serializers.py" unset PYSPARK_DOC_TEST +run_test "pyspark/shuffle.py" run_test "pyspark/tests.py" run_test "pyspark/mllib/_common.py" run_test "pyspark/mllib/classification.py" diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh new file mode 100755 index 0000000000000..8398e6f19b511 --- /dev/null +++ b/sbin/start-thriftserver.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Shell script for starting the Spark SQL Thrift server + +# Enter posix mode for bash +set -o posix + +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/start-thriftserver [options]" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit 0 +fi + +CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" +exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@ diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 6decde3fcd62d..531bfddbf237b 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -32,7 +32,7 @@ Spark Project Catalyst http://spark.apache.org/ - catalyst + catalyst diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c7188469bfb86..02bdb64f308a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ - /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing * when all relations are already filled in and the analyser needs only to resolve attribute @@ -54,6 +53,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool StarExpansion :: ResolveFunctions :: GlobalAggregates :: + UnresolvedHavingClauseAttributes :: typeCoercionRules :_*), Batch("Check Analysis", Once, CheckResolution), @@ -151,6 +151,31 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool } } + /** + * This rule finds expressions in HAVING clause filters that depend on + * unresolved attributes. It pushes these expressions down to the underlying + * aggregates and then projects them away above the filter. + */ + object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) + if !filter.resolved && aggregate.resolved && containsAggregate(havingCondition) => { + val evaluatedCondition = Alias(havingCondition, "havingCondition")() + val aggExprsWithHaving = evaluatedCondition +: originalAggExprs + + Project(aggregate.output, + Filter(evaluatedCondition.toAttribute, + aggregate.copy(aggregateExpressions = aggExprsWithHaving))) + } + + } + + protected def containsAggregate(condition: Expression): Boolean = + condition + .collect { case ae: AggregateExpression => ae } + .nonEmpty + } + /** * When a SELECT clause has only a single expression and that expression is a * [[catalyst.expressions.Generator Generator]] we convert the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 76ddeba9cb312..67a8ce9b88c3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -231,11 +231,23 @@ trait HiveTypeCoercion { * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated. */ object BooleanComparisons extends Rule[LogicalPlan] { + val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, BigDecimal(1)).map(Literal(_)) + val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, BigDecimal(0)).map(Literal(_)) + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - // No need to change EqualTo operators as that actually makes sense for boolean types. + + // Hive treats (true = 1) as true and (false = 0) as true. + case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l + case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r + case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l) + case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r) + + // No need to change other EqualTo operators as that actually makes sense for boolean types. case e: EqualTo => e + // No need to change the EqualNullSafe operators, too + case e: EqualNullSafe => e // Otherwise turn them to Byte types so that there exists and ordering. case p: BinaryComparison if p.left.dataType == BooleanType && p.right.dataType == BooleanType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 1b503b957d146..5c8c810d9135a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -77,10 +77,27 @@ package object dsl { def > (other: Expression) = GreaterThan(expr, other) def >= (other: Expression) = GreaterThanOrEqual(expr, other) def === (other: Expression) = EqualTo(expr, other) + def <=> (other: Expression) = EqualNullSafe(expr, other) def !== (other: Expression) = Not(EqualTo(expr, other)) + def in(list: Expression*) = In(expr, list) + def like(other: Expression) = Like(expr, other) def rlike(other: Expression) = RLike(expr, other) + def contains(other: Expression) = Contains(expr, other) + def startsWith(other: Expression) = StartsWith(expr, other) + def endsWith(other: Expression) = EndsWith(expr, other) + def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)) = + Substring(expr, pos, len) + def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)) = + Substring(expr, pos, len) + + def isNull = IsNull(expr) + def isNotNull = IsNotNull(expr) + + def getItem(ordinal: Expression) = GetItem(expr, ordinal) + def getField(fieldName: String) = GetField(expr, fieldName) + def cast(to: DataType) = Cast(expr, to) def asc = SortOrder(expr, Ascending) @@ -112,6 +129,7 @@ package object dsl { def sumDistinct(e: Expression) = SumDistinct(e) def count(e: Expression) = Count(e) def countDistinct(e: Expression*) = CountDistinct(e) + def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd) def avg(e: Expression) = Average(e) def first(e: Expression) = First(e) def min(e: Expression) = Min(e) @@ -163,6 +181,18 @@ package object dsl { /** Creates a new AttributeReference of type binary */ def binary = AttributeReference(s, BinaryType, nullable = true)() + + /** Creates a new AttributeReference of type array */ + def array(dataType: DataType) = AttributeReference(s, ArrayType(dataType), nullable = true)() + + /** Creates a new AttributeReference of type map */ + def map(keyType: DataType, valueType: DataType): AttributeReference = + map(MapType(keyType, valueType)) + def map(mapType: MapType) = AttributeReference(s, mapType, nullable = true)() + + /** Creates a new AttributeReference of type struct */ + def struct(fields: StructField*): AttributeReference = struct(StructType(fields)) + def struct(structType: StructType) = AttributeReference(s, structType, nullable = true)() } implicit class DslAttribute(a: AttributeReference) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b63406b94a4a3..06b94a98d3cd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -153,6 +153,22 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison } } +case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { + def symbol = "<=>" + override def nullable = false + override def eval(input: Row): Any = { + val l = left.eval(input) + val r = right.eval(input) + if (l == null && r == null) { + true + } else if (l == null || r == null) { + false + } else { + l == r + } + } +} + case class LessThan(left: Expression, right: Expression) extends BinaryComparison { def symbol = "<" override def eval(input: Row): Any = c2(input, left, right, _.lt(_, _)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c65987b7120b2..5f86d6047cb9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -153,6 +153,8 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType) + case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) + case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) // For Coalesce, remove null literals. case e @ Coalesce(children) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 1d5f033f0d274..a357c6ffb8977 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -43,8 +43,7 @@ case class NativeCommand(cmd: String) extends Command { */ case class SetCommand(key: Option[String], value: Option[String]) extends Command { override def output = Seq( - BoundReference(0, AttributeReference("key", StringType, nullable = false)()), - BoundReference(1, AttributeReference("value", StringType, nullable = false)())) + BoundReference(1, AttributeReference("", StringType, nullable = false)())) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index db1ae29d400c6..58f8c341e6676 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -301,17 +301,17 @@ class ExpressionEvaluationSuite extends FunSuite { val c3 = 'a.boolean.at(2) val c4 = 'a.boolean.at(3) - checkEvaluation(IsNull(c1), false, row) - checkEvaluation(IsNotNull(c1), true, row) + checkEvaluation(c1.isNull, false, row) + checkEvaluation(c1.isNotNull, true, row) - checkEvaluation(IsNull(c2), true, row) - checkEvaluation(IsNotNull(c2), false, row) + checkEvaluation(c2.isNull, true, row) + checkEvaluation(c2.isNotNull, false, row) - checkEvaluation(IsNull(Literal(1, ShortType)), false) - checkEvaluation(IsNotNull(Literal(1, ShortType)), true) + checkEvaluation(Literal(1, ShortType).isNull, false) + checkEvaluation(Literal(1, ShortType).isNotNull, true) - checkEvaluation(IsNull(Literal(null, ShortType)), true) - checkEvaluation(IsNotNull(Literal(null, ShortType)), false) + checkEvaluation(Literal(null, ShortType).isNull, true) + checkEvaluation(Literal(null, ShortType).isNotNull, false) checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row) @@ -326,11 +326,11 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(If(Literal(false, BooleanType), Literal("a", StringType), Literal("b", StringType)), "b", row) - checkEvaluation(In(c1, c1 :: c2 :: Nil), true, row) - checkEvaluation(In(Literal("^Ba*n", StringType), - Literal("^Ba*n", StringType) :: Nil), true, row) - checkEvaluation(In(Literal("^Ba*n", StringType), - Literal("^Ba*n", StringType) :: c2 :: Nil), true, row) + checkEvaluation(c1 in (c1, c2), true, row) + checkEvaluation( + Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType)), true, row) + checkEvaluation( + Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType), c2), true, row) } test("case when") { @@ -420,6 +420,10 @@ class ExpressionEvaluationSuite extends FunSuite { assert(GetField(Literal(null, typeS), "a").nullable === true) assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true) + + checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row) + checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row) + checkEvaluation('c.struct(typeS).at(2).getField("a"), "aa", row) } test("arithmetic") { @@ -447,11 +451,13 @@ class ExpressionEvaluationSuite extends FunSuite { } test("BinaryComparison") { - val row = new GenericRow(Array[Any](1, 2, 3, null)) + val row = new GenericRow(Array[Any](1, 2, 3, null, 3, null)) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) val c4 = 'a.int.at(3) + val c5 = 'a.int.at(4) + val c6 = 'a.int.at(5) checkEvaluation(LessThan(c1, c4), null, row) checkEvaluation(LessThan(c1, c2), true, row) @@ -465,6 +471,12 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c1 >= c2, false, row) checkEvaluation(c1 === c2, false, row) checkEvaluation(c1 !== c2, true, row) + checkEvaluation(c4 <=> c1, false, row) + checkEvaluation(c1 <=> c4, false, row) + checkEvaluation(c4 <=> c6, true, row) + checkEvaluation(c3 <=> c5, true, row) + checkEvaluation(Literal(true) <=> Literal(null, BooleanType), false, row) + checkEvaluation(Literal(null, BooleanType) <=> Literal(true), false, row) } test("StringComparison") { @@ -472,20 +484,20 @@ class ExpressionEvaluationSuite extends FunSuite { val c1 = 'a.string.at(0) val c2 = 'a.string.at(1) - checkEvaluation(Contains(c1, "b"), true, row) - checkEvaluation(Contains(c1, "x"), false, row) - checkEvaluation(Contains(c2, "b"), null, row) - checkEvaluation(Contains(c1, Literal(null, StringType)), null, row) + checkEvaluation(c1 contains "b", true, row) + checkEvaluation(c1 contains "x", false, row) + checkEvaluation(c2 contains "b", null, row) + checkEvaluation(c1 contains Literal(null, StringType), null, row) - checkEvaluation(StartsWith(c1, "a"), true, row) - checkEvaluation(StartsWith(c1, "b"), false, row) - checkEvaluation(StartsWith(c2, "a"), null, row) - checkEvaluation(StartsWith(c1, Literal(null, StringType)), null, row) + checkEvaluation(c1 startsWith "a", true, row) + checkEvaluation(c1 startsWith "b", false, row) + checkEvaluation(c2 startsWith "a", null, row) + checkEvaluation(c1 startsWith Literal(null, StringType), null, row) - checkEvaluation(EndsWith(c1, "c"), true, row) - checkEvaluation(EndsWith(c1, "b"), false, row) - checkEvaluation(EndsWith(c2, "b"), null, row) - checkEvaluation(EndsWith(c1, Literal(null, StringType)), null, row) + checkEvaluation(c1 endsWith "c", true, row) + checkEvaluation(c1 endsWith "b", false, row) + checkEvaluation(c2 endsWith "b", null, row) + checkEvaluation(c1 endsWith Literal(null, StringType), null, row) } test("Substring") { @@ -542,5 +554,10 @@ class ExpressionEvaluationSuite extends FunSuite { assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false) assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true) assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true) + + checkEvaluation(s.substr(0, 2), "ex", row) + checkEvaluation(s.substr(0), "example", row) + checkEvaluation(s.substring(0, 2), "ex", row) + checkEvaluation(s.substring(0), "example", row) } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index c309c43804d97..3a038a2db6173 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -32,7 +32,7 @@ Spark Project SQL http://spark.apache.org/ - sql + sql diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2b787e14f3f15..41920c00b5a2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -30,12 +30,13 @@ import scala.collection.JavaConverters._ * SQLConf is thread-safe (internally synchronized so safe to be used in multiple threads). */ trait SQLConf { + import SQLConf._ /** ************************ Spark SQL Params/Hints ******************* */ // TODO: refactor so that these hints accessors don't pollute the name space of SQLContext? /** Number of partitions to use for shuffle operators. */ - private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt + private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to @@ -43,11 +44,10 @@ trait SQLConf { * effectively disables auto conversion. * Hive setting: hive.auto.convert.join.noconditionaltask.size. */ - private[spark] def autoConvertJoinSize: Int = - get("spark.sql.auto.convert.join.size", "10000").toInt + private[spark] def autoConvertJoinSize: Int = get(AUTO_CONVERT_JOIN_SIZE, "10000").toInt /** A comma-separated list of table names marked to be broadcasted during joins. */ - private[spark] def joinBroadcastTables: String = get("spark.sql.join.broadcastTables", "") + private[spark] def joinBroadcastTables: String = get(JOIN_BROADCAST_TABLES, "") /** ********************** SQLConf functionality methods ************ */ @@ -61,7 +61,7 @@ trait SQLConf { def set(key: String, value: String): Unit = { require(key != null, "key cannot be null") - require(value != null, s"value cannot be null for ${key}") + require(value != null, s"value cannot be null for $key") settings.put(key, value) } @@ -90,3 +90,13 @@ trait SQLConf { } } + +object SQLConf { + val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size" + val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" + val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables" + + object Deprecated { + val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 34b355e906695..34654447a5f4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -24,10 +24,10 @@ import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLog import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Serializer, Kryo} -import com.twitter.chill.AllScalaRegistrar +import com.twitter.chill.{AllScalaRegistrar, ResourcePool} import org.apache.spark.{SparkEnv, SparkConf} -import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} import org.apache.spark.util.MutablePair import org.apache.spark.util.Utils @@ -48,22 +48,41 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co } } -private[sql] object SparkSqlSerializer { - // TODO (lian) Using KryoSerializer here is workaround, needs further investigation - // Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization - // related error. - @transient lazy val ser: KryoSerializer = { +private[execution] class KryoResourcePool(size: Int) + extends ResourcePool[SerializerInstance](size) { + + val ser: KryoSerializer = { val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + // TODO (lian) Using KryoSerializer here is workaround, needs further investigation + // Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization + // related error. new KryoSerializer(sparkConf) } - def serialize[T: ClassTag](o: T): Array[Byte] = { - ser.newInstance().serialize(o).array() - } + def newInstance() = ser.newInstance() +} - def deserialize[T: ClassTag](bytes: Array[Byte]): T = { - ser.newInstance().deserialize[T](ByteBuffer.wrap(bytes)) +private[sql] object SparkSqlSerializer { + @transient lazy val resourcePool = new KryoResourcePool(30) + + private[this] def acquireRelease[O](fn: SerializerInstance => O): O = { + val kryo = resourcePool.borrow + try { + fn(kryo) + } finally { + resourcePool.release(kryo) + } } + + def serialize[T: ClassTag](o: T): Array[Byte] = + acquireRelease { k => + k.serialize(o).array() + } + + def deserialize[T: ClassTag](bytes: Array[Byte]): T = + acquireRelease { k => + k.deserialize[T](ByteBuffer.wrap(bytes)) + } } private[sql] class BigDecimalSerializer extends Serializer[BigDecimal] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 98d2f89c8ae71..9293239131d52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution +import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SQLConf, SQLContext} trait Command { /** @@ -44,28 +45,53 @@ trait Command { case class SetCommand( key: Option[String], value: Option[String], output: Seq[Attribute])( @transient context: SQLContext) - extends LeafNode with Command { + extends LeafNode with Command with Logging { - override protected[sql] lazy val sideEffectResult: Seq[(String, String)] = (key, value) match { + override protected[sql] lazy val sideEffectResult: Seq[String] = (key, value) match { // Set value for key k. case (Some(k), Some(v)) => - context.set(k, v) - Array(k -> v) + if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { + logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") + context.set(SQLConf.SHUFFLE_PARTITIONS, v) + Array(s"${SQLConf.SHUFFLE_PARTITIONS}=$v") + } else { + context.set(k, v) + Array(s"$k=$v") + } // Query the value bound to key k. case (Some(k), _) => - Array(k -> context.getOption(k).getOrElse("")) + // TODO (lian) This is just a workaround to make the Simba ODBC driver work. + // Should remove this once we get the ODBC driver updated. + if (k == "-v") { + val hiveJars = Seq( + "hive-exec-0.12.0.jar", + "hive-service-0.12.0.jar", + "hive-common-0.12.0.jar", + "hive-hwi-0.12.0.jar", + "hive-0.12.0.jar").mkString(":") + + Array( + "system:java.class.path=" + hiveJars, + "system:sun.java.command=shark.SharkServer2") + } + else { + Array(s"$k=${context.getOption(k).getOrElse("")}") + } // Query all key-value pairs that are set in the SQLConf of the context. case (None, None) => - context.getAll + context.getAll.map { case (k, v) => + s"$k=$v" + } case _ => throw new IllegalArgumentException() } def execute(): RDD[Row] = { - val rows = sideEffectResult.map { case (k, v) => new GenericRow(Array[Any](k, v)) } + val rows = sideEffectResult.map { line => new GenericRow(Array[Any](line)) } context.sparkContext.parallelize(rows, 1) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index df80dfb98b93c..b48c70ee73a27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.json -import scala.collection.JavaConversions._ +import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal import com.fasterxml.jackson.databind.ObjectMapper @@ -210,12 +210,12 @@ private[sql] object JsonRDD extends Logging { case (k, dataType) => (s"$key.$k", dataType) } ++ Set((key, StructType(Nil))) } - case (key: String, array: List[_]) => { + case (key: String, array: Seq[_]) => { // The value associated with the key is an array. typeOfArray(array) match { case ArrayType(StructType(Nil)) => { // The elements of this arrays are structs. - array.asInstanceOf[List[Map[String, Any]]].flatMap { + array.asInstanceOf[Seq[Map[String, Any]]].flatMap { element => allKeysWithValueTypes(element) }.map { case (k, dataType) => (s"$key.$k", dataType) @@ -229,7 +229,7 @@ private[sql] object JsonRDD extends Logging { } /** - * Converts a Java Map/List to a Scala Map/List. + * Converts a Java Map/List to a Scala Map/Seq. * We do not use Jackson's scala module at here because * DefaultScalaModule in jackson-module-scala will make * the parsing very slow. @@ -239,9 +239,9 @@ private[sql] object JsonRDD extends Logging { // .map(identity) is used as a workaround of non-serializable Map // generated by .mapValues. // This issue is documented at https://issues.scala-lang.org/browse/SI-7005 - map.toMap.mapValues(scalafy).map(identity) + JMapWrapper(map).mapValues(scalafy).map(identity) case list: java.util.List[_] => - list.toList.map(scalafy) + JListWrapper(list).map(scalafy) case atom => atom } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index c8ea01c4e1b6a..1a6a6c17473a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test._ /* Implicits */ @@ -41,15 +40,15 @@ class DslQuerySuite extends QueryTest { test("agg") { checkAnswer( - testData2.groupBy('a)('a, Sum('b)), + testData2.groupBy('a)('a, sum('b)), Seq((1,3),(2,3),(3,3)) ) checkAnswer( - testData2.groupBy('a)('a, Sum('b) as 'totB).aggregate(Sum('totB)), + testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)), 9 ) checkAnswer( - testData2.aggregate(Sum('b)), + testData2.aggregate(sum('b)), 9 ) } @@ -104,19 +103,19 @@ class DslQuerySuite extends QueryTest { Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) checkAnswer( - arrayData.orderBy(GetItem('data, 0).asc), + arrayData.orderBy('data.getItem(0).asc), arrayData.collect().sortBy(_.data(0)).toSeq) checkAnswer( - arrayData.orderBy(GetItem('data, 0).desc), + arrayData.orderBy('data.getItem(0).desc), arrayData.collect().sortBy(_.data(0)).reverse.toSeq) checkAnswer( - mapData.orderBy(GetItem('data, 1).asc), + mapData.orderBy('data.getItem(1).asc), mapData.collect().sortBy(_.data(1)).toSeq) checkAnswer( - mapData.orderBy(GetItem('data, 1).desc), + mapData.orderBy('data.getItem(1).desc), mapData.collect().sortBy(_.data(1)).reverse.toSeq) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 08293f7f0ca30..1a58d73d9e7f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -54,10 +54,10 @@ class SQLConfSuite extends QueryTest { assert(get(testKey, testVal + "_") == testVal) assert(TestSQLContext.get(testKey, testVal + "_") == testVal) - sql("set mapred.reduce.tasks=20") - assert(get("mapred.reduce.tasks", "0") == "20") - sql("set mapred.reduce.tasks = 40") - assert(get("mapred.reduce.tasks", "0") == "40") + sql("set some.property=20") + assert(get("some.property", "0") == "20") + sql("set some.property = 40") + assert(get("some.property", "0") == "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" @@ -70,4 +70,9 @@ class SQLConfSuite extends QueryTest { clear() } + test("deprecated property") { + clear() + sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + assert(get(SQLConf.SHUFFLE_PARTITIONS) == "10") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6736189c96d4b..de9e8aa4f62ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -424,25 +424,25 @@ class SQLQuerySuite extends QueryTest { sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Seq(Seq(testKey, testVal)) + Seq(Seq(s"$testKey=$testVal")) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), Seq( - Seq(testKey, testVal), - Seq(testKey + testKey, testVal + testVal)) + Seq(s"$testKey=$testVal"), + Seq(s"${testKey + testKey}=${testVal + testVal}")) ) // "set key" checkAnswer( sql(s"SET $testKey"), - Seq(Seq(testKey, testVal)) + Seq(Seq(s"$testKey=$testVal")) ) checkAnswer( sql(s"SET $nonexistentKey"), - Seq(Seq(nonexistentKey, "")) + Seq(Seq(s"$nonexistentKey=")) ) clear() } diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml new file mode 100644 index 0000000000000..7fac90fdc596d --- /dev/null +++ b/sql/hive-thriftserver/pom.xml @@ -0,0 +1,82 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.1.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-hive-thriftserver_2.10 + jar + Spark Project Hive + http://spark.apache.org/ + + hive-thriftserver + + + + + org.apache.spark + spark-hive_${scala.binary.version} + ${project.version} + + + org.spark-project.hive + hive-cli + ${hive.version} + + + org.spark-project.hive + hive-jdbc + ${hive.version} + + + org.spark-project.hive + hive-beeline + ${hive.version} + + + org.scalatest + scalatest_${scala.binary.version} + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + + diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala new file mode 100644 index 0000000000000..ddbc2a79fb512 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import org.apache.commons.logging.LogFactory +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService +import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} + +import org.apache.spark.sql.Logging +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ + +/** + * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a + * `HiveThriftServer2` thrift server. + */ +private[hive] object HiveThriftServer2 extends Logging { + var LOG = LogFactory.getLog(classOf[HiveServer2]) + + def main(args: Array[String]) { + val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") + + if (!optionsProcessor.process(args)) { + logger.warn("Error starting HiveThriftServer2 with given arguments") + System.exit(-1) + } + + val ss = new SessionState(new HiveConf(classOf[SessionState])) + + // Set all properties specified via command line. + val hiveConf: HiveConf = ss.getConf + hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) => + logger.debug(s"HiveConf var: $k=$v") + } + + SessionState.start(ss) + + logger.info("Starting SparkContext") + SparkSQLEnv.init() + SessionState.start(ss) + + Runtime.getRuntime.addShutdownHook( + new Thread() { + override def run() { + SparkSQLEnv.sparkContext.stop() + } + } + ) + + try { + val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) + server.init(hiveConf) + server.start() + logger.info("HiveThriftServer2 started") + } catch { + case e: Exception => + logger.error("Error starting HiveThriftServer2", e) + System.exit(-1) + } + } +} + +private[hive] class HiveThriftServer2(hiveContext: HiveContext) + extends HiveServer2 + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + val sparkSqlCliService = new SparkSQLCLIService(hiveContext) + setSuperField(this, "cliService", sparkSqlCliService) + addService(sparkSqlCliService) + + val thriftCliService = new ThriftBinaryCLIService(sparkSqlCliService) + setSuperField(this, "thriftCLIService", thriftCliService) + addService(thriftCliService) + + initCompositeService(hiveConf) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala new file mode 100644 index 0000000000000..599294dfbb7d7 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +private[hive] object ReflectionUtils { + def setSuperField(obj : Object, fieldName: String, fieldValue: Object) { + setAncestorField(obj, 1, fieldName, fieldValue) + } + + def setAncestorField(obj: AnyRef, level: Int, fieldName: String, fieldValue: AnyRef) { + val ancestor = Iterator.iterate[Class[_]](obj.getClass)(_.getSuperclass).drop(level).next() + val field = ancestor.getDeclaredField(fieldName) + field.setAccessible(true) + field.set(obj, fieldValue) + } + + def getSuperField[T](obj: AnyRef, fieldName: String): T = { + getAncestorField[T](obj, 1, fieldName) + } + + def getAncestorField[T](clazz: Object, level: Int, fieldName: String): T = { + val ancestor = Iterator.iterate[Class[_]](clazz.getClass)(_.getSuperclass).drop(level).next() + val field = ancestor.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(clazz).asInstanceOf[T] + } + + def invokeStatic(clazz: Class[_], methodName: String, args: (Class[_], AnyRef)*): AnyRef = { + invoke(clazz, null, methodName, args: _*) + } + + def invoke( + clazz: Class[_], + obj: AnyRef, + methodName: String, + args: (Class[_], AnyRef)*): AnyRef = { + + val (types, values) = args.unzip + val method = clazz.getDeclaredMethod(methodName, types: _*) + method.setAccessible(true) + method.invoke(obj, values.toSeq: _*) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala new file mode 100755 index 0000000000000..27268ecb923e9 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.io._ +import java.util.{ArrayList => JArrayList} + +import jline.{ConsoleReader, History} +import org.apache.commons.lang.StringUtils +import org.apache.commons.logging.LogFactory +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} +import org.apache.hadoop.hive.common.LogUtils.LogInitializationException +import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.exec.Utilities +import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.thrift.transport.TSocket + +import org.apache.spark.sql.Logging + +private[hive] object SparkSQLCLIDriver { + private var prompt = "spark-sql" + private var continuedPrompt = "".padTo(prompt.length, ' ') + private var transport:TSocket = _ + + installSignalHandler() + + /** + * Install an interrupt callback to cancel all Spark jobs. In Hive's CliDriver#processLine(), + * a signal handler will invoke this registered callback if a Ctrl+C signal is detected while + * a command is being processed by the current thread. + */ + def installSignalHandler() { + HiveInterruptUtils.add(new HiveInterruptCallback { + override def interrupt() { + // Handle remote execution mode + if (SparkSQLEnv.sparkContext != null) { + SparkSQLEnv.sparkContext.cancelAllJobs() + } else { + if (transport != null) { + // Force closing of TCP connection upon session termination + transport.getSocket.close() + } + } + } + }) + } + + def main(args: Array[String]) { + val oproc = new OptionsProcessor() + if (!oproc.process_stage1(args)) { + System.exit(1) + } + + // NOTE: It is critical to do this here so that log4j is reinitialized + // before any of the other core hive classes are loaded + var logInitFailed = false + var logInitDetailMessage: String = null + try { + logInitDetailMessage = LogUtils.initHiveLog4j() + } catch { + case e: LogInitializationException => + logInitFailed = true + logInitDetailMessage = e.getMessage + } + + val sessionState = new CliSessionState(new HiveConf(classOf[SessionState])) + + sessionState.in = System.in + try { + sessionState.out = new PrintStream(System.out, true, "UTF-8") + sessionState.info = new PrintStream(System.err, true, "UTF-8") + sessionState.err = new PrintStream(System.err, true, "UTF-8") + } catch { + case e: UnsupportedEncodingException => System.exit(3) + } + + if (!oproc.process_stage2(sessionState)) { + System.exit(2) + } + + if (!sessionState.getIsSilent) { + if (logInitFailed) System.err.println(logInitDetailMessage) + else SessionState.getConsole.printInfo(logInitDetailMessage) + } + + // Set all properties specified via command line. + val conf: HiveConf = sessionState.getConf + sessionState.cmdProperties.entrySet().foreach { item: java.util.Map.Entry[Object, Object] => + conf.set(item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String]) + sessionState.getOverriddenConfigurations.put( + item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String]) + } + + SessionState.start(sessionState) + + // Clean up after we exit + Runtime.getRuntime.addShutdownHook( + new Thread() { + override def run() { + SparkSQLEnv.stop() + } + } + ) + + // "-h" option has been passed, so connect to Hive thrift server. + if (sessionState.getHost != null) { + sessionState.connect() + if (sessionState.isRemoteMode) { + prompt = s"[${sessionState.getHost}:${sessionState.getPort}]" + prompt + continuedPrompt = "".padTo(prompt.length, ' ') + } + } + + if (!sessionState.isRemoteMode && !ShimLoader.getHadoopShims.usesJobShell()) { + // Hadoop-20 and above - we need to augment classpath using hiveconf + // components. + // See also: code in ExecDriver.java + var loader = conf.getClassLoader + val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS) + if (StringUtils.isNotBlank(auxJars)) { + loader = Utilities.addToClassPath(loader, StringUtils.split(auxJars, ",")) + } + conf.setClassLoader(loader) + Thread.currentThread().setContextClassLoader(loader) + } + + val cli = new SparkSQLCLIDriver + cli.setHiveVariables(oproc.getHiveVariables) + + // TODO work around for set the log output to console, because the HiveContext + // will set the output into an invalid buffer. + sessionState.in = System.in + try { + sessionState.out = new PrintStream(System.out, true, "UTF-8") + sessionState.info = new PrintStream(System.err, true, "UTF-8") + sessionState.err = new PrintStream(System.err, true, "UTF-8") + } catch { + case e: UnsupportedEncodingException => System.exit(3) + } + + // Execute -i init files (always in silent mode) + cli.processInitFiles(sessionState) + + if (sessionState.execString != null) { + System.exit(cli.processLine(sessionState.execString)) + } + + try { + if (sessionState.fileName != null) { + System.exit(cli.processFile(sessionState.fileName)) + } + } catch { + case e: FileNotFoundException => + System.err.println(s"Could not open input file for reading. (${e.getMessage})") + System.exit(3) + } + + val reader = new ConsoleReader() + reader.setBellEnabled(false) + // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true))) + CliDriver.getCommandCompletor.foreach((e) => reader.addCompletor(e)) + + val historyDirectory = System.getProperty("user.home") + + try { + if (new File(historyDirectory).exists()) { + val historyFile = historyDirectory + File.separator + ".hivehistory" + reader.setHistory(new History(new File(historyFile))) + } else { + System.err.println("WARNING: Directory for Hive history file: " + historyDirectory + + " does not exist. History will not be available during this session.") + } + } catch { + case e: Exception => + System.err.println("WARNING: Encountered an error while trying to initialize Hive's " + + "history file. History will not be available during this session.") + System.err.println(e.getMessage) + } + + val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") + clientTransportTSocketField.setAccessible(true) + + transport = clientTransportTSocketField.get(sessionState).asInstanceOf[TSocket] + + var ret = 0 + var prefix = "" + val currentDB = ReflectionUtils.invokeStatic(classOf[CliDriver], "getFormattedDb", + classOf[HiveConf] -> conf, classOf[CliSessionState] -> sessionState) + + def promptWithCurrentDB = s"$prompt$currentDB" + def continuedPromptWithDBSpaces = continuedPrompt + ReflectionUtils.invokeStatic( + classOf[CliDriver], "spacesForString", classOf[String] -> currentDB) + + var currentPrompt = promptWithCurrentDB + var line = reader.readLine(currentPrompt + "> ") + + while (line != null) { + if (prefix.nonEmpty) { + prefix += '\n' + } + + if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) { + line = prefix + line + ret = cli.processLine(line, true) + prefix = "" + currentPrompt = promptWithCurrentDB + } else { + prefix = prefix + line + currentPrompt = continuedPromptWithDBSpaces + } + + line = reader.readLine(currentPrompt + "> ") + } + + sessionState.close() + + System.exit(ret) + } +} + +private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { + private val sessionState = SessionState.get().asInstanceOf[CliSessionState] + + private val LOG = LogFactory.getLog("CliDriver") + + private val console = new SessionState.LogHelper(LOG) + + private val conf: Configuration = + if (sessionState != null) sessionState.getConf else new Configuration() + + // Force initializing SparkSQLEnv. This is put here but not object SparkSQLCliDriver + // because the Hive unit tests do not go through the main() code path. + if (!sessionState.isRemoteMode) { + SparkSQLEnv.init() + } + + override def processCmd(cmd: String): Int = { + val cmd_trimmed: String = cmd.trim() + val tokens: Array[String] = cmd_trimmed.split("\\s+") + val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() + if (cmd_trimmed.toLowerCase.equals("quit") || + cmd_trimmed.toLowerCase.equals("exit") || + tokens(0).equalsIgnoreCase("source") || + cmd_trimmed.startsWith("!") || + tokens(0).toLowerCase.equals("list") || + sessionState.isRemoteMode) { + val start = System.currentTimeMillis() + super.processCmd(cmd) + val end = System.currentTimeMillis() + val timeTaken: Double = (end - start) / 1000.0 + console.printInfo(s"Time taken: $timeTaken seconds") + 0 + } else { + var ret = 0 + val hconf = conf.asInstanceOf[HiveConf] + val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hconf) + + if (proc != null) { + if (proc.isInstanceOf[Driver]) { + val driver = new SparkSQLDriver + + driver.init() + val out = sessionState.out + val start:Long = System.currentTimeMillis() + if (sessionState.getIsVerbose) { + out.println(cmd) + } + + ret = driver.run(cmd).getResponseCode + if (ret != 0) { + driver.close() + return ret + } + + val res = new JArrayList[String]() + + if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CLI_PRINT_HEADER)) { + // Print the column names. + Option(driver.getSchema.getFieldSchemas).map { fields => + out.println(fields.map(_.getName).mkString("\t")) + } + } + + try { + while (!out.checkError() && driver.getResults(res)) { + res.foreach(out.println) + res.clear() + } + } catch { + case e:IOException => + console.printError( + s"""Failed with exception ${e.getClass.getName}: ${e.getMessage} + |${org.apache.hadoop.util.StringUtils.stringifyException(e)} + """.stripMargin) + ret = 1 + } + + val cret = driver.close() + if (ret == 0) { + ret = cret + } + + val end = System.currentTimeMillis() + if (end > start) { + val timeTaken:Double = (end - start) / 1000.0 + console.printInfo(s"Time taken: $timeTaken seconds", null) + } + + // Destroy the driver to release all the locks. + driver.destroy() + } else { + if (sessionState.getIsVerbose) { + sessionState.out.println(tokens(0) + " " + cmd_1) + } + ret = proc.run(cmd_1).getResponseCode + } + } + ret + } + } +} + diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala new file mode 100644 index 0000000000000..42cbf363b274f --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.io.IOException +import java.util.{List => JList} +import javax.security.auth.login.LoginException + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hive.service.Service.STATE +import org.apache.hive.service.auth.HiveAuthFactory +import org.apache.hive.service.cli.CLIService +import org.apache.hive.service.{AbstractService, Service, ServiceException} + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ + +private[hive] class SparkSQLCLIService(hiveContext: HiveContext) + extends CLIService + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext) + setSuperField(this, "sessionManager", sparkSqlSessionManager) + addService(sparkSqlSessionManager) + + try { + HiveAuthFactory.loginFromKeytab(hiveConf) + val serverUserName = ShimLoader.getHadoopShims + .getShortUserName(ShimLoader.getHadoopShims.getUGIForConf(hiveConf)) + setSuperField(this, "serverUserName", serverUserName) + } catch { + case e @ (_: IOException | _: LoginException) => + throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) + } + + initCompositeService(hiveConf) + } +} + +private[thriftserver] trait ReflectedCompositeService { this: AbstractService => + def initCompositeService(hiveConf: HiveConf) { + // Emulating `CompositeService.init(hiveConf)` + val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList") + serviceList.foreach(_.init(hiveConf)) + + // Emulating `AbstractService.init(hiveConf)` + invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.NOTINITED) + setAncestorField(this, 3, "hiveConf", hiveConf) + invoke(classOf[AbstractService], this, "changeState", classOf[STATE] -> STATE.INITED) + getAncestorField[Log](this, 3, "LOG").info(s"Service: $getName is inited.") + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala new file mode 100644 index 0000000000000..5202aa9903e03 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ + +import java.util.{ArrayList => JArrayList} + +import org.apache.commons.lang.exception.ExceptionUtils +import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse + +import org.apache.spark.sql.Logging +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} + +private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext) + extends Driver with Logging { + + private var tableSchema: Schema = _ + private var hiveResponse: Seq[String] = _ + + override def init(): Unit = { + } + + private def getResultSetSchema(query: context.QueryExecution): Schema = { + val analyzed = query.analyzed + logger.debug(s"Result Schema: ${analyzed.output}") + if (analyzed.output.size == 0) { + new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) + } else { + val fieldSchemas = analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + + new Schema(fieldSchemas, null) + } + } + + override def run(command: String): CommandProcessorResponse = { + val execution = context.executePlan(context.hql(command).logicalPlan) + + // TODO unify the error code + try { + hiveResponse = execution.stringResult() + tableSchema = getResultSetSchema(execution) + new CommandProcessorResponse(0) + } catch { + case cause: Throwable => + logger.error(s"Failed in [$command]", cause) + new CommandProcessorResponse(-3, ExceptionUtils.getFullStackTrace(cause), null) + } + } + + override def close(): Int = { + hiveResponse = null + tableSchema = null + 0 + } + + override def getSchema: Schema = tableSchema + + override def getResults(res: JArrayList[String]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.addAll(hiveResponse) + hiveResponse = null + true + } + } + + override def destroy() { + super.destroy() + hiveResponse = null + tableSchema = null + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala new file mode 100644 index 0000000000000..451c3bd7b9352 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import org.apache.hadoop.hive.ql.session.SessionState + +import org.apache.spark.scheduler.{SplitInfo, StatsReportListener} +import org.apache.spark.sql.Logging +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.{SparkConf, SparkContext} + +/** A singleton object for the master program. The slaves should not access this. */ +private[hive] object SparkSQLEnv extends Logging { + logger.debug("Initializing SparkSQLEnv") + + var hiveContext: HiveContext = _ + var sparkContext: SparkContext = _ + + def init() { + if (hiveContext == null) { + sparkContext = new SparkContext(new SparkConf() + .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}")) + + sparkContext.addSparkListener(new StatsReportListener()) + + hiveContext = new HiveContext(sparkContext) { + @transient override lazy val sessionState = SessionState.get() + @transient override lazy val hiveconf = sessionState.getConf + } + } + } + + /** Cleans up and shuts down the Spark SQL environments. */ + def stop() { + logger.debug("Shutting down Spark SQL Environment") + // Stop the SparkContext + if (SparkSQLEnv.sparkContext != null) { + sparkContext.stop() + sparkContext = null + hiveContext = null + } + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala new file mode 100644 index 0000000000000..6b3275b4eaf04 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.concurrent.Executors + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.session.SessionManager + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager + +private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) + extends SessionManager + with ReflectedCompositeService { + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) + setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) + getAncestorField[Log](this, 3, "LOG").info( + s"HiveServer2: Async execution pool size $backgroundPoolSize") + + val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + setSuperField(this, "operationManager", sparkSqlOperationManager) + addService(sparkSqlOperationManager) + + initCompositeService(hiveConf) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala new file mode 100644 index 0000000000000..a4e1f3e762e89 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver.server + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.math.{random, round} + +import java.sql.Timestamp +import java.util.{Map => JMap} + +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} +import org.apache.hive.service.cli.session.HiveSession + +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.{Logging, SchemaRDD, Row => SparkRow} + +/** + * Executes queries using Spark SQL, and maintains a list of handles to active queries. + */ +class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManager with Logging { + val handleToOperation = ReflectionUtils + .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation") + + override def newExecuteStatementOperation( + parentSession: HiveSession, + statement: String, + confOverlay: JMap[String, String], + async: Boolean): ExecuteStatementOperation = synchronized { + + val operation = new ExecuteStatementOperation(parentSession, statement, confOverlay) { + private var result: SchemaRDD = _ + private var iter: Iterator[SparkRow] = _ + private var dataTypes: Array[DataType] = _ + + def close(): Unit = { + // RDDs will be cleaned automatically upon garbage collection. + logger.debug("CLOSING") + } + + def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { + if (!iter.hasNext) { + new RowSet() + } else { + val maxRows = maxRowsL.toInt // Do you really want a row batch larger than Int Max? No. + var curRow = 0 + var rowSet = new ArrayBuffer[Row](maxRows) + + while (curRow < maxRows && iter.hasNext) { + val sparkRow = iter.next() + val row = new Row() + var curCol = 0 + + while (curCol < sparkRow.length) { + dataTypes(curCol) match { + case StringType => + row.addString(sparkRow(curCol).asInstanceOf[String]) + case IntegerType => + row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol))) + case BooleanType => + row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol))) + case DoubleType => + row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol))) + case FloatType => + row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol))) + case DecimalType => + val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal + row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) + case LongType => + row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol))) + case ByteType => + row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol))) + case ShortType => + row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol))) + case TimestampType => + row.addColumnValue( + ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp])) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + val hiveString = result + .queryExecution + .asInstanceOf[HiveContext#QueryExecution] + .toHiveString((sparkRow.get(curCol), dataTypes(curCol))) + row.addColumnValue(ColumnValue.stringValue(hiveString)) + } + curCol += 1 + } + rowSet += row + curRow += 1 + } + new RowSet(rowSet, 0) + } + } + + def getResultSetSchema: TableSchema = { + logger.warn(s"Result Schema: ${result.queryExecution.analyzed.output}") + if (result.queryExecution.analyzed.output.size == 0) { + new TableSchema(new FieldSchema("Result", "string", "") :: Nil) + } else { + val schema = result.queryExecution.analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + new TableSchema(schema) + } + } + + def run(): Unit = { + logger.info(s"Running query '$statement'") + setState(OperationState.RUNNING) + try { + result = hiveContext.hql(statement) + logger.debug(result.queryExecution.toString()) + val groupId = round(random * 1000000).toString + hiveContext.sparkContext.setJobGroup(groupId, statement) + iter = result.queryExecution.toRdd.toLocalIterator + dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray + setHasResultSet(true) + } catch { + // Actually do need to catch Throwable as some failures don't inherit from Exception and + // HiveServer will silently swallow them. + case e: Throwable => + logger.error("Error executing query:",e) + throw new HiveSQLException(e.toString) + } + setState(OperationState.FINISHED) + } + } + + handleToOperation.put(operation.getHandle, operation) + operation + } +} diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt new file mode 100644 index 0000000000000..850f8014b6f05 --- /dev/null +++ b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt @@ -0,0 +1,5 @@ +238val_238 +86val_86 +311val_311 +27val_27 +165val_165 diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala new file mode 100644 index 0000000000000..b90670a796b81 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.{BufferedReader, InputStreamReader, PrintWriter} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql.hive.test.TestHive + +class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils { + val WAREHOUSE_PATH = TestUtils.getWarehousePath("cli") + val METASTORE_PATH = TestUtils.getMetastorePath("cli") + + override def beforeAll() { + val pb = new ProcessBuilder( + "../../bin/spark-sql", + "--master", + "local", + "--hiveconf", + s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", + "--hiveconf", + "hive.metastore.warehouse.dir=" + WAREHOUSE_PATH) + + process = pb.start() + outputWriter = new PrintWriter(process.getOutputStream, true) + inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) + errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream)) + waitForOutput(inputReader, "spark-sql>") + } + + override def afterAll() { + process.destroy() + process.waitFor() + } + + test("simple commands") { + val dataFilePath = getDataFile("data/files/small_kv.txt") + executeQuery("create table hive_test1(key int, val string);") + executeQuery("load data local inpath '" + dataFilePath+ "' overwrite into table hive_test1;") + executeQuery("cache table hive_test1", "Time taken") + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala new file mode 100644 index 0000000000000..59f4952b78bc6 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import scala.collection.JavaConversions._ +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent._ + +import java.io.{BufferedReader, InputStreamReader} +import java.sql.{Connection, DriverManager, Statement} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql.Logging +import org.apache.spark.sql.catalyst.util.getTempFilePath + +/** + * Test for the HiveThriftServer2 using JDBC. + */ +class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUtils with Logging { + + val WAREHOUSE_PATH = getTempFilePath("warehouse") + val METASTORE_PATH = getTempFilePath("metastore") + + val DRIVER_NAME = "org.apache.hive.jdbc.HiveDriver" + val TABLE = "test" + // use a different port, than the hive standard 10000, + // for tests to avoid issues with the port being taken on some machines + val PORT = "10000" + + // If verbose is true, the test program will print all outputs coming from the Hive Thrift server. + val VERBOSE = Option(System.getenv("SPARK_SQL_TEST_VERBOSE")).getOrElse("false").toBoolean + + Class.forName(DRIVER_NAME) + + override def beforeAll() { launchServer() } + + override def afterAll() { stopServer() } + + private def launchServer(args: Seq[String] = Seq.empty) { + // Forking a new process to start the Hive Thrift server. The reason to do this is it is + // hard to clean up Hive resources entirely, so we just start a new process and kill + // that process for cleanup. + val defaultArgs = Seq( + "../../sbin/start-thriftserver.sh", + "--master local", + "--hiveconf", + "hive.root.logger=INFO,console", + "--hiveconf", + s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", + "--hiveconf", + s"hive.metastore.warehouse.dir=$WAREHOUSE_PATH") + val pb = new ProcessBuilder(defaultArgs ++ args) + process = pb.start() + inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) + errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream)) + waitForOutput(inputReader, "ThriftBinaryCLIService listening on") + + // Spawn a thread to read the output from the forked process. + // Note that this is necessary since in some configurations, log4j could be blocked + // if its output to stderr are not read, and eventually blocking the entire test suite. + future { + while (true) { + val stdout = readFrom(inputReader) + val stderr = readFrom(errorReader) + if (VERBOSE && stdout.length > 0) { + println(stdout) + } + if (VERBOSE && stderr.length > 0) { + println(stderr) + } + Thread.sleep(50) + } + } + } + + private def stopServer() { + process.destroy() + process.waitFor() + } + + test("test query execution against a Hive Thrift server") { + Thread.sleep(5 * 1000) + val dataFilePath = getDataFile("data/files/small_kv.txt") + val stmt = createStatement() + stmt.execute("DROP TABLE IF EXISTS test") + stmt.execute("DROP TABLE IF EXISTS test_cached") + stmt.execute("CREATE TABLE test(key int, val string)") + stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test") + stmt.execute("CREATE TABLE test_cached as select * from test limit 4") + stmt.execute("CACHE TABLE test_cached") + + var rs = stmt.executeQuery("select count(*) from test") + rs.next() + assert(rs.getInt(1) === 5) + + rs = stmt.executeQuery("select count(*) from test_cached") + rs.next() + assert(rs.getInt(1) === 4) + + stmt.close() + } + + def getConnection: Connection = { + val connectURI = s"jdbc:hive2://localhost:$PORT/" + DriverManager.getConnection(connectURI, System.getProperty("user.name"), "") + } + + def createStatement(): Statement = getConnection.createStatement() +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala new file mode 100644 index 0000000000000..bb2242618fbef --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.{BufferedReader, PrintWriter} +import java.text.SimpleDateFormat +import java.util.Date + +import org.apache.hadoop.hive.common.LogUtils +import org.apache.hadoop.hive.common.LogUtils.LogInitializationException + +object TestUtils { + val timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss") + + def getWarehousePath(prefix: String): String = { + System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-warehouse-" + + timestamp.format(new Date) + } + + def getMetastorePath(prefix: String): String = { + System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-metastore-" + + timestamp.format(new Date) + } + + // Dummy function for initialize the log4j properties. + def init() { } + + // initialize log4j + try { + LogUtils.initHiveLog4j() + } catch { + case e: LogInitializationException => // Ignore the error. + } +} + +trait TestUtils { + var process : Process = null + var outputWriter : PrintWriter = null + var inputReader : BufferedReader = null + var errorReader : BufferedReader = null + + def executeQuery( + cmd: String, outputMessage: String = "OK", timeout: Long = 15000): String = { + println("Executing: " + cmd + ", expecting output: " + outputMessage) + outputWriter.write(cmd + "\n") + outputWriter.flush() + waitForQuery(timeout, outputMessage) + } + + protected def waitForQuery(timeout: Long, message: String): String = { + if (waitForOutput(errorReader, message, timeout)) { + Thread.sleep(500) + readOutput() + } else { + assert(false, "Didn't find \"" + message + "\" in the output:\n" + readOutput()) + null + } + } + + // Wait for the specified str to appear in the output. + protected def waitForOutput( + reader: BufferedReader, str: String, timeout: Long = 10000): Boolean = { + val startTime = System.currentTimeMillis + var out = "" + while (!out.contains(str) && System.currentTimeMillis < (startTime + timeout)) { + out += readFrom(reader) + } + out.contains(str) + } + + // Read stdout output and filter out garbage collection messages. + protected def readOutput(): String = { + val output = readFrom(inputReader) + // Remove GC Messages + val filteredOutput = output.lines.filterNot(x => x.contains("[GC") || x.contains("[Full GC")) + .mkString("\n") + filteredOutput + } + + protected def readFrom(reader: BufferedReader): String = { + var out = "" + var c = 0 + while (reader.ready) { + c = reader.read() + out += c.asInstanceOf[Char] + } + out + } + + protected def getDataFile(name: String) = { + Thread.currentThread().getContextClassLoader.getResource(name) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala similarity index 98% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala rename to sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index fd44325925cdd..c69e93ba2b9ba 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -196,7 +196,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Hive returns the results of describe as plain text. Comments with multiple lines // introduce extra lines in the Hive results, which make the result comparison fail. - "describe_comment_indent" + "describe_comment_indent", + + // Limit clause without a ordering, which causes failure. + "orc_predicate_pushdown" ) /** @@ -291,6 +294,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "correlationoptimizer1", "correlationoptimizer10", "correlationoptimizer11", + "correlationoptimizer13", "correlationoptimizer14", "correlationoptimizer15", "correlationoptimizer2", @@ -299,6 +303,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "correlationoptimizer6", "correlationoptimizer7", "correlationoptimizer8", + "correlationoptimizer9", "count", "cp_mj_rc", "create_insert_outputformat", @@ -389,6 +394,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "groupby_sort_8", "groupby_sort_9", "groupby_sort_test_1", + "having", + "having1", "implicit_cast1", "innerjoin", "inoutdriver", @@ -499,6 +506,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "join_hive_626", "join_map_ppr", "join_nulls", + "join_nullsafe", "join_rc", "join_reorder2", "join_reorder3", @@ -730,6 +738,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_double", "udf_E", "udf_elt", + "udf_equal", "udf_exp", "udf_field", "udf_find_in_set", diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index f30ae28b81e06..93d00f7c37c9b 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -32,7 +32,7 @@ Spark Project Hive http://spark.apache.org/ - hive + hive @@ -102,6 +102,36 @@ test
+ + + + hive + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-scala-test-sources + generate-test-sources + + add-test-source + + + + src/test/scala + compatibility/src/test/scala + + + + + + + + + + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 334462357eb86..84d43eaeea51d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -253,9 +253,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, - ShortType, DecimalType, TimestampType) + ShortType, DecimalType, TimestampType, BinaryType) - protected def toHiveString(a: (Any, DataType)): String = a match { + protected[sql] def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => struct.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" @@ -269,6 +269,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { }.toSeq.sorted.mkString("{", ",", "}") case (null, _) => "NULL" case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString + case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") case (other, tpe) if primitiveTypes contains tpe => other.toString } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala new file mode 100644 index 0000000000000..ad7dc0ecdb1bf --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.{io => hiveIo} +import org.apache.hadoop.{io => hadoopIo} + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types +import org.apache.spark.sql.catalyst.types._ + +/* Implicit conversions */ +import scala.collection.JavaConversions._ + +private[hive] trait HiveInspectors { + + def javaClassToDataType(clz: Class[_]): DataType = clz match { + // writable + case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType + case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType + case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType + case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType + case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType + case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType + case c: Class[_] if c == classOf[hadoopIo.Text] => StringType + case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType + case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType + case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType + case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType + case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType + + // java class + case c: Class[_] if c == classOf[java.lang.String] => StringType + case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType + case c: Class[_] if c == classOf[HiveDecimal] => DecimalType + case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType + case c: Class[_] if c == classOf[Array[Byte]] => BinaryType + case c: Class[_] if c == classOf[java.lang.Short] => ShortType + case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType + case c: Class[_] if c == classOf[java.lang.Long] => LongType + case c: Class[_] if c == classOf[java.lang.Double] => DoubleType + case c: Class[_] if c == classOf[java.lang.Byte] => ByteType + case c: Class[_] if c == classOf[java.lang.Float] => FloatType + case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType + + // primitive type + case c: Class[_] if c == java.lang.Short.TYPE => ShortType + case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType + case c: Class[_] if c == java.lang.Long.TYPE => LongType + case c: Class[_] if c == java.lang.Double.TYPE => DoubleType + case c: Class[_] if c == java.lang.Byte.TYPE => ByteType + case c: Class[_] if c == java.lang.Float.TYPE => FloatType + case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType + + case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType)) + } + + /** Converts hive types to native catalyst types. */ + def unwrap(a: Any): Any = a match { + case null => null + case i: hadoopIo.IntWritable => i.get + case t: hadoopIo.Text => t.toString + case l: hadoopIo.LongWritable => l.get + case d: hadoopIo.DoubleWritable => d.get + case d: hiveIo.DoubleWritable => d.get + case s: hiveIo.ShortWritable => s.get + case b: hadoopIo.BooleanWritable => b.get + case b: hiveIo.ByteWritable => b.get + case b: hadoopIo.FloatWritable => b.get + case b: hadoopIo.BytesWritable => { + val bytes = new Array[Byte](b.getLength) + System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength) + bytes + } + case t: hiveIo.TimestampWritable => t.getTimestamp + case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue()) + case list: java.util.List[_] => list.map(unwrap) + case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap + case array: Array[_] => array.map(unwrap).toSeq + case p: java.lang.Short => p + case p: java.lang.Long => p + case p: java.lang.Float => p + case p: java.lang.Integer => p + case p: java.lang.Double => p + case p: java.lang.Byte => p + case p: java.lang.Boolean => p + case str: String => str + case p: java.math.BigDecimal => p + case p: Array[Byte] => p + case p: java.sql.Timestamp => p + } + + def unwrapData(data: Any, oi: ObjectInspector): Any = oi match { + case hvoi: HiveVarcharObjectInspector => + if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue + case hdoi: HiveDecimalObjectInspector => + if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) + case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data) + case li: ListObjectInspector => + Option(li.getList(data)) + .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq) + .orNull + case mi: MapObjectInspector => + Option(mi.getMap(data)).map( + _.map { + case (k,v) => + (unwrapData(k, mi.getMapKeyObjectInspector), + unwrapData(v, mi.getMapValueObjectInspector)) + }.toMap).orNull + case si: StructObjectInspector => + val allRefs = si.getAllStructFieldRefs + new GenericRow( + allRefs.map(r => + unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) + } + + /** Converts native catalyst types to the types expected by Hive */ + def wrap(a: Any): AnyRef = a match { + case s: String => new hadoopIo.Text(s) // TODO why should be Text? + case i: Int => i: java.lang.Integer + case b: Boolean => b: java.lang.Boolean + case f: Float => f: java.lang.Float + case d: Double => d: java.lang.Double + case l: Long => l: java.lang.Long + case l: Short => l: java.lang.Short + case l: Byte => l: java.lang.Byte + case b: BigDecimal => b.bigDecimal + case b: Array[Byte] => b + case t: java.sql.Timestamp => t + case s: Seq[_] => seqAsJavaList(s.map(wrap)) + case m: Map[_,_] => + mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) }) + case null => null + } + + def toInspector(dataType: DataType): ObjectInspector = dataType match { + case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) + case MapType(keyType, valueType) => + ObjectInspectorFactory.getStandardMapObjectInspector( + toInspector(keyType), toInspector(valueType)) + case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector + case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector + case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector + case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector + case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector + case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector + case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector + case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector + case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector + case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector + case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector + case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector + case StructType(fields) => + ObjectInspectorFactory.getStandardStructObjectInspector( + fields.map(f => f.name), fields.map(f => toInspector(f.dataType))) + } + + def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { + case s: StructObjectInspector => + StructType(s.getAllStructFieldRefs.map(f => { + types.StructField( + f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true) + })) + case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) + case m: MapObjectInspector => + MapType( + inspectorToDataType(m.getMapKeyObjectInspector), + inspectorToDataType(m.getMapValueObjectInspector)) + case _: WritableStringObjectInspector => StringType + case _: JavaStringObjectInspector => StringType + case _: WritableIntObjectInspector => IntegerType + case _: JavaIntObjectInspector => IntegerType + case _: WritableDoubleObjectInspector => DoubleType + case _: JavaDoubleObjectInspector => DoubleType + case _: WritableBooleanObjectInspector => BooleanType + case _: JavaBooleanObjectInspector => BooleanType + case _: WritableLongObjectInspector => LongType + case _: JavaLongObjectInspector => LongType + case _: WritableShortObjectInspector => ShortType + case _: JavaShortObjectInspector => ShortType + case _: WritableByteObjectInspector => ByteType + case _: JavaByteObjectInspector => ByteType + case _: WritableFloatObjectInspector => FloatType + case _: JavaFloatObjectInspector => FloatType + case _: WritableBinaryObjectInspector => BinaryType + case _: JavaBinaryObjectInspector => BinaryType + case _: WritableHiveDecimalObjectInspector => DecimalType + case _: JavaHiveDecimalObjectInspector => DecimalType + case _: WritableTimestampObjectInspector => TimestampType + case _: JavaTimestampObjectInspector => TimestampType + } + + implicit class typeInfoConversions(dt: DataType) { + import org.apache.hadoop.hive.serde2.typeinfo._ + import TypeInfoFactory._ + + def toTypeInfo: TypeInfo = dt match { + case BinaryType => binaryTypeInfo + case BooleanType => booleanTypeInfo + case ByteType => byteTypeInfo + case DoubleType => doubleTypeInfo + case FloatType => floatTypeInfo + case IntegerType => intTypeInfo + case LongType => longTypeInfo + case ShortType => shortTypeInfo + case StringType => stringTypeInfo + case DecimalType => decimalTypeInfo + case TimestampType => timestampTypeInfo + case NullType => voidTypeInfo + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 8db60d32767b5..156b090712df2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -258,7 +258,7 @@ private[hive] case class MetastoreRelation // org.apache.hadoop.hive.ql.metadata.Partition will cause a NotSerializableException // which indicates the SerDe we used is not Serializable. - def hiveQlTable = new Table(table) + @transient lazy val hiveQlTable = new Table(table) def hiveQlPartitions = partitions.map { p => new Partition(hiveQlTable, p) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 300e249f5b2e1..4395874526d51 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -42,8 +42,6 @@ private[hive] case class ShellCommand(cmd: String) extends Command private[hive] case class SourceCommand(filePath: String) extends Command -private[hive] case class AddJar(jarPath: String) extends Command - private[hive] case class AddFile(filePath: String) extends Command /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ @@ -229,7 +227,7 @@ private[hive] object HiveQl { } else if (sql.trim.toLowerCase.startsWith("uncache table")) { CacheCommand(sql.trim.drop(14).trim, false) } else if (sql.trim.toLowerCase.startsWith("add jar")) { - AddJar(sql.trim.drop(8)) + NativeCommand(sql) } else if (sql.trim.toLowerCase.startsWith("add file")) { AddFile(sql.trim.drop(9)) } else if (sql.trim.toLowerCase.startsWith("dfs")) { @@ -932,6 +930,8 @@ private[hive] object HiveQl { /* Comparisons */ case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) + case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) + case Token("<=>", left :: right:: Nil) => EqualNullSafe(nodeToExpr(left), nodeToExpr(right)) case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index fc33c5b460d70..057eb60a02612 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -24,22 +24,19 @@ import org.apache.hadoop.hive.ql.exec.UDF import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.hive.serde2.objectinspector.primitive._ -import org.apache.hadoop.hive.serde2.{io => hiveIo} -import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.sql.Logging import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.util.Utils.getContextOrSparkClassLoader /* Implicit conversions */ import scala.collection.JavaConversions._ -private[hive] object HiveFunctionRegistry - extends analysis.FunctionRegistry with HiveFunctionFactory with HiveInspectors { +private[hive] object HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors { + + def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) def lookupFunction(name: String, children: Seq[Expression]): Expression = { // We only look it up to see if it exists, but do not include it in the HiveUDF since it is @@ -47,111 +44,37 @@ private[hive] object HiveFunctionRegistry val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name)).getOrElse( sys.error(s"Couldn't find function $name")) + val functionClassName = functionInfo.getFunctionClass.getName() + if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - val function = createFunction[UDF](name) + val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF] val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType) HiveSimpleUdf( - name, + functionClassName, children.zip(expectedDataTypes).map { case (e, t) => Cast(e, t) } ) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdf(name, children) + HiveGenericUdf(functionClassName, children) } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdaf(name, children) + HiveGenericUdaf(functionClassName, children) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(name, Nil, children) + HiveGenericUdtf(functionClassName, Nil, children) } else { sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") } } - - def javaClassToDataType(clz: Class[_]): DataType = clz match { - // writable - case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType - case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType - case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType - case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType - case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType - case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType - case c: Class[_] if c == classOf[hadoopIo.Text] => StringType - case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType - case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType - case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType - case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType - case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType - - // java class - case c: Class[_] if c == classOf[java.lang.String] => StringType - case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType - case c: Class[_] if c == classOf[HiveDecimal] => DecimalType - case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType - case c: Class[_] if c == classOf[Array[Byte]] => BinaryType - case c: Class[_] if c == classOf[java.lang.Short] => ShortType - case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType - case c: Class[_] if c == classOf[java.lang.Long] => LongType - case c: Class[_] if c == classOf[java.lang.Double] => DoubleType - case c: Class[_] if c == classOf[java.lang.Byte] => ByteType - case c: Class[_] if c == classOf[java.lang.Float] => FloatType - case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType - - // primitive type - case c: Class[_] if c == java.lang.Short.TYPE => ShortType - case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType - case c: Class[_] if c == java.lang.Long.TYPE => LongType - case c: Class[_] if c == java.lang.Double.TYPE => DoubleType - case c: Class[_] if c == java.lang.Byte.TYPE => ByteType - case c: Class[_] if c == java.lang.Float.TYPE => FloatType - case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType - - case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType)) - } } private[hive] trait HiveFunctionFactory { - def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) - def getFunctionClass(name: String) = getFunctionInfo(name).getFunctionClass - def createFunction[UDFType](name: String) = - getFunctionClass(name).newInstance.asInstanceOf[UDFType] - - /** Converts hive types to native catalyst types. */ - def unwrap(a: Any): Any = a match { - case null => null - case i: hadoopIo.IntWritable => i.get - case t: hadoopIo.Text => t.toString - case l: hadoopIo.LongWritable => l.get - case d: hadoopIo.DoubleWritable => d.get - case d: hiveIo.DoubleWritable => d.get - case s: hiveIo.ShortWritable => s.get - case b: hadoopIo.BooleanWritable => b.get - case b: hiveIo.ByteWritable => b.get - case b: hadoopIo.FloatWritable => b.get - case b: hadoopIo.BytesWritable => { - val bytes = new Array[Byte](b.getLength) - System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength) - bytes - } - case t: hiveIo.TimestampWritable => t.getTimestamp - case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue()) - case list: java.util.List[_] => list.map(unwrap) - case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap - case array: Array[_] => array.map(unwrap).toSeq - case p: java.lang.Short => p - case p: java.lang.Long => p - case p: java.lang.Float => p - case p: java.lang.Integer => p - case p: java.lang.Double => p - case p: java.lang.Byte => p - case p: java.lang.Boolean => p - case str: String => str - case p: java.math.BigDecimal => p - case p: Array[Byte] => p - case p: java.sql.Timestamp => p - } + val functionClassName: String + + def createFunction[UDFType]() = + getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType] } private[hive] abstract class HiveUdf extends Expression with Logging with HiveFunctionFactory { @@ -160,19 +83,17 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu type UDFType type EvaluatedType = Any - val name: String - def nullable = true def references = children.flatMap(_.references).toSet - // FunctionInfo is not serializable so we must look it up here again. - lazy val functionInfo = getFunctionInfo(name) - lazy val function = createFunction[UDFType](name) + lazy val function = createFunction[UDFType]() - override def toString = s"$nodeName#${functionInfo.getDisplayName}(${children.mkString(",")})" + override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" } -private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUdf { +private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression]) + extends HiveUdf { + import org.apache.spark.sql.hive.HiveFunctionRegistry._ type UDFType = UDF @@ -226,7 +147,7 @@ private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression]) } } -private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression]) +private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq[Expression]) extends HiveUdf with HiveInspectors { import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ @@ -277,131 +198,8 @@ private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression]) } } -private[hive] trait HiveInspectors { - - def unwrapData(data: Any, oi: ObjectInspector): Any = oi match { - case hvoi: HiveVarcharObjectInspector => - if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue - case hdoi: HiveDecimalObjectInspector => - if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) - case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data) - case li: ListObjectInspector => - Option(li.getList(data)) - .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq) - .orNull - case mi: MapObjectInspector => - Option(mi.getMap(data)).map( - _.map { - case (k,v) => - (unwrapData(k, mi.getMapKeyObjectInspector), - unwrapData(v, mi.getMapValueObjectInspector)) - }.toMap).orNull - case si: StructObjectInspector => - val allRefs = si.getAllStructFieldRefs - new GenericRow( - allRefs.map(r => - unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) - } - - /** Converts native catalyst types to the types expected by Hive */ - def wrap(a: Any): AnyRef = a match { - case s: String => new hadoopIo.Text(s) // TODO why should be Text? - case i: Int => i: java.lang.Integer - case b: Boolean => b: java.lang.Boolean - case f: Float => f: java.lang.Float - case d: Double => d: java.lang.Double - case l: Long => l: java.lang.Long - case l: Short => l: java.lang.Short - case l: Byte => l: java.lang.Byte - case b: BigDecimal => b.bigDecimal - case b: Array[Byte] => b - case t: java.sql.Timestamp => t - case s: Seq[_] => seqAsJavaList(s.map(wrap)) - case m: Map[_,_] => - mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) }) - case null => null - } - - def toInspector(dataType: DataType): ObjectInspector = dataType match { - case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) - case MapType(keyType, valueType) => - ObjectInspectorFactory.getStandardMapObjectInspector( - toInspector(keyType), toInspector(valueType)) - case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector - case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector - case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector - case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector - case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector - case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector - case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector - case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector - case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector - case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector - case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector - case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector - case StructType(fields) => - ObjectInspectorFactory.getStandardStructObjectInspector( - fields.map(f => f.name), fields.map(f => toInspector(f.dataType))) - } - - def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { - case s: StructObjectInspector => - StructType(s.getAllStructFieldRefs.map(f => { - types.StructField( - f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true) - })) - case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) - case m: MapObjectInspector => - MapType( - inspectorToDataType(m.getMapKeyObjectInspector), - inspectorToDataType(m.getMapValueObjectInspector)) - case _: WritableStringObjectInspector => StringType - case _: JavaStringObjectInspector => StringType - case _: WritableIntObjectInspector => IntegerType - case _: JavaIntObjectInspector => IntegerType - case _: WritableDoubleObjectInspector => DoubleType - case _: JavaDoubleObjectInspector => DoubleType - case _: WritableBooleanObjectInspector => BooleanType - case _: JavaBooleanObjectInspector => BooleanType - case _: WritableLongObjectInspector => LongType - case _: JavaLongObjectInspector => LongType - case _: WritableShortObjectInspector => ShortType - case _: JavaShortObjectInspector => ShortType - case _: WritableByteObjectInspector => ByteType - case _: JavaByteObjectInspector => ByteType - case _: WritableFloatObjectInspector => FloatType - case _: JavaFloatObjectInspector => FloatType - case _: WritableBinaryObjectInspector => BinaryType - case _: JavaBinaryObjectInspector => BinaryType - case _: WritableHiveDecimalObjectInspector => DecimalType - case _: JavaHiveDecimalObjectInspector => DecimalType - case _: WritableTimestampObjectInspector => TimestampType - case _: JavaTimestampObjectInspector => TimestampType - } - - implicit class typeInfoConversions(dt: DataType) { - import org.apache.hadoop.hive.serde2.typeinfo._ - import TypeInfoFactory._ - - def toTypeInfo: TypeInfo = dt match { - case BinaryType => binaryTypeInfo - case BooleanType => booleanTypeInfo - case ByteType => byteTypeInfo - case DoubleType => doubleTypeInfo - case FloatType => floatTypeInfo - case IntegerType => intTypeInfo - case LongType => longTypeInfo - case ShortType => shortTypeInfo - case StringType => stringTypeInfo - case DecimalType => decimalTypeInfo - case TimestampType => timestampTypeInfo - case NullType => voidTypeInfo - } - } -} - private[hive] case class HiveGenericUdaf( - name: String, + functionClassName: String, children: Seq[Expression]) extends AggregateExpression with HiveInspectors with HiveFunctionFactory { @@ -409,7 +207,7 @@ private[hive] case class HiveGenericUdaf( type UDFType = AbstractGenericUDAFResolver @transient - protected lazy val resolver: AbstractGenericUDAFResolver = createFunction(name) + protected lazy val resolver: AbstractGenericUDAFResolver = createFunction() @transient protected lazy val objectInspector = { @@ -426,9 +224,9 @@ private[hive] case class HiveGenericUdaf( def references: Set[Attribute] = children.map(_.references).flatten.toSet - override def toString = s"$nodeName#$name(${children.mkString(",")})" + override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" - def newInstance() = new HiveUdafFunction(name, children, this) + def newInstance() = new HiveUdafFunction(functionClassName, children, this) } /** @@ -443,7 +241,7 @@ private[hive] case class HiveGenericUdaf( * user defined aggregations, which have clean semantics even in a partitioned execution. */ private[hive] case class HiveGenericUdtf( - name: String, + functionClassName: String, aliasNames: Seq[String], children: Seq[Expression]) extends Generator with HiveInspectors with HiveFunctionFactory { @@ -451,7 +249,7 @@ private[hive] case class HiveGenericUdtf( override def references = children.flatMap(_.references).toSet @transient - protected lazy val function: GenericUDTF = createFunction(name) + protected lazy val function: GenericUDTF = createFunction() protected lazy val inputInspectors = children.map(_.dataType).map(toInspector) @@ -506,11 +304,11 @@ private[hive] case class HiveGenericUdtf( } } - override def toString = s"$nodeName#$name(${children.mkString(",")})" + override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" } private[hive] case class HiveUdafFunction( - functionName: String, + functionClassName: String, exprs: Seq[Expression], base: AggregateExpression) extends AggregateFunction @@ -519,7 +317,7 @@ private[hive] case class HiveUdafFunction( def this() = this(null, null, null) - private val resolver = createFunction[AbstractGenericUDAFResolver](functionName) + private val resolver = createFunction[AbstractGenericUDAFResolver]() private val inspectors = exprs.map(_.dataType).map(toInspector).toArray diff --git a/sql/hive/src/test/resources/golden/boolean = number-0-6b6975fa1892cc48edd87dc0df48a7c0 b/sql/hive/src/test/resources/golden/boolean = number-0-6b6975fa1892cc48edd87dc0df48a7c0 new file mode 100644 index 0000000000000..4d1ebdcde2c71 --- /dev/null +++ b/sql/hive/src/test/resources/golden/boolean = number-0-6b6975fa1892cc48edd87dc0df48a7c0 @@ -0,0 +1 @@ +true true true true true true false false false false false false false false false false false false true true true true true true false false false false false false false false false false false false diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer13-0-efd135a811fa94760736a761d220b82 b/sql/hive/src/test/resources/golden/correlationoptimizer13-0-efd135a811fa94760736a761d220b82 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer13-1-32a82500cc28465fac6f64dde0c431c6 b/sql/hive/src/test/resources/golden/correlationoptimizer13-1-32a82500cc28465fac6f64dde0c431c6 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer13-2-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer13-2-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer13-2-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer13-3-bb61d9292434f37bd386e5bff683764d b/sql/hive/src/test/resources/golden/correlationoptimizer13-3-bb61d9292434f37bd386e5bff683764d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-0-efd135a811fa94760736a761d220b82 b/sql/hive/src/test/resources/golden/correlationoptimizer9-0-efd135a811fa94760736a761d220b82 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-1-b1e2ade89ae898650f0be4f796d8947b b/sql/hive/src/test/resources/golden/correlationoptimizer9-1-b1e2ade89ae898650f0be4f796d8947b new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-1-b1e2ade89ae898650f0be4f796d8947b @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-10-1190d82f88f7fb1f91968f6e2e03772a b/sql/hive/src/test/resources/golden/correlationoptimizer9-10-1190d82f88f7fb1f91968f6e2e03772a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-11-bc2ae88b17ac2bdbd288e07194a40168 b/sql/hive/src/test/resources/golden/correlationoptimizer9-11-bc2ae88b17ac2bdbd288e07194a40168 new file mode 100644 index 0000000000000..17c838bb62b3b --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-11-bc2ae88b17ac2bdbd288e07194a40168 @@ -0,0 +1,9 @@ +103 val_103 103 val_103 4 4 +104 val_104 104 val_104 4 4 +105 val_105 105 val_105 1 1 +111 val_111 111 val_111 1 1 +113 val_113 113 val_113 4 4 +114 val_114 114 val_114 1 1 +116 val_116 116 val_116 1 1 +118 val_118 118 val_118 4 4 +119 val_119 119 val_119 9 9 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-12-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer9-12-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-12-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-13-1190d82f88f7fb1f91968f6e2e03772a b/sql/hive/src/test/resources/golden/correlationoptimizer9-13-1190d82f88f7fb1f91968f6e2e03772a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-14-bc2ae88b17ac2bdbd288e07194a40168 b/sql/hive/src/test/resources/golden/correlationoptimizer9-14-bc2ae88b17ac2bdbd288e07194a40168 new file mode 100644 index 0000000000000..17c838bb62b3b --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-14-bc2ae88b17ac2bdbd288e07194a40168 @@ -0,0 +1,9 @@ +103 val_103 103 val_103 4 4 +104 val_104 104 val_104 4 4 +105 val_105 105 val_105 1 1 +111 val_111 111 val_111 1 1 +113 val_113 113 val_113 4 4 +114 val_114 114 val_114 1 1 +116 val_116 116 val_116 1 1 +118 val_118 118 val_118 4 4 +119 val_119 119 val_119 9 9 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-2-32a82500cc28465fac6f64dde0c431c6 b/sql/hive/src/test/resources/golden/correlationoptimizer9-2-32a82500cc28465fac6f64dde0c431c6 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-3-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer9-3-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-3-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-4-ec131bcf578dba99f20b16a7dc6b9b b/sql/hive/src/test/resources/golden/correlationoptimizer9-4-ec131bcf578dba99f20b16a7dc6b9b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-5-b4e378104bb5ab8d8ba5f905aa1ff450 b/sql/hive/src/test/resources/golden/correlationoptimizer9-5-b4e378104bb5ab8d8ba5f905aa1ff450 new file mode 100644 index 0000000000000..248a14f1f4a9f --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-5-b4e378104bb5ab8d8ba5f905aa1ff450 @@ -0,0 +1,9 @@ +103 103 4 4 +104 104 4 4 +105 105 1 1 +111 111 1 1 +113 113 4 4 +114 114 1 1 +116 116 1 1 +118 118 4 4 +119 119 9 9 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-6-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer9-6-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-6-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-7-f952899d70bd718cbdbc44a5290938c9 b/sql/hive/src/test/resources/golden/correlationoptimizer9-7-f952899d70bd718cbdbc44a5290938c9 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-8-b4e378104bb5ab8d8ba5f905aa1ff450 b/sql/hive/src/test/resources/golden/correlationoptimizer9-8-b4e378104bb5ab8d8ba5f905aa1ff450 new file mode 100644 index 0000000000000..248a14f1f4a9f --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-8-b4e378104bb5ab8d8ba5f905aa1ff450 @@ -0,0 +1,9 @@ +103 103 4 4 +104 104 4 4 +105 105 1 1 +111 111 1 1 +113 113 4 4 +114 114 1 1 +116 116 1 1 +118 118 4 4 +119 119 9 9 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-9-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer9-9-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-9-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/having-0-57f3f26c0203c29c2a91a7cca557ce55 b/sql/hive/src/test/resources/golden/having-0-57f3f26c0203c29c2a91a7cca557ce55 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/having-1-ef81808faeab6d212c3cf32abfc0d873 b/sql/hive/src/test/resources/golden/having-1-ef81808faeab6d212c3cf32abfc0d873 new file mode 100644 index 0000000000000..704f1e62f14c5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/having-1-ef81808faeab6d212c3cf32abfc0d873 @@ -0,0 +1,10 @@ +4 +4 +5 +4 +5 +5 +4 +4 +5 +4 diff --git a/sql/hive/src/test/resources/golden/having-2-a2b4f52cb92f730ddb912b063636d6c1 b/sql/hive/src/test/resources/golden/having-2-a2b4f52cb92f730ddb912b063636d6c1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/having-3-3fa6387b6a4ece110ac340c7b893964e b/sql/hive/src/test/resources/golden/having-3-3fa6387b6a4ece110ac340c7b893964e new file mode 100644 index 0000000000000..b56757a60f780 --- /dev/null +++ b/sql/hive/src/test/resources/golden/having-3-3fa6387b6a4ece110ac340c7b893964e @@ -0,0 +1,308 @@ +0 val_0 +2 val_2 +4 val_4 +5 val_5 +8 val_8 +9 val_9 +10 val_10 +11 val_11 +12 val_12 +15 val_15 +17 val_17 +18 val_18 +19 val_19 +20 val_20 +24 val_24 +26 val_26 +27 val_27 +28 val_28 +30 val_30 +33 val_33 +34 val_34 +35 val_35 +37 val_37 +41 val_41 +42 val_42 +43 val_43 +44 val_44 +47 val_47 +51 val_51 +53 val_53 +54 val_54 +57 val_57 +58 val_58 +64 val_64 +65 val_65 +66 val_66 +67 val_67 +69 val_69 +70 val_70 +72 val_72 +74 val_74 +76 val_76 +77 val_77 +78 val_78 +80 val_80 +82 val_82 +83 val_83 +84 val_84 +85 val_85 +86 val_86 +87 val_87 +90 val_90 +92 val_92 +95 val_95 +96 val_96 +97 val_97 +98 val_98 +100 val_100 +103 val_103 +104 val_104 +105 val_105 +111 val_111 +113 val_113 +114 val_114 +116 val_116 +118 val_118 +119 val_119 +120 val_120 +125 val_125 +126 val_126 +128 val_128 +129 val_129 +131 val_131 +133 val_133 +134 val_134 +136 val_136 +137 val_137 +138 val_138 +143 val_143 +145 val_145 +146 val_146 +149 val_149 +150 val_150 +152 val_152 +153 val_153 +155 val_155 +156 val_156 +157 val_157 +158 val_158 +160 val_160 +162 val_162 +163 val_163 +164 val_164 +165 val_165 +166 val_166 +167 val_167 +168 val_168 +169 val_169 +170 val_170 +172 val_172 +174 val_174 +175 val_175 +176 val_176 +177 val_177 +178 val_178 +179 val_179 +180 val_180 +181 val_181 +183 val_183 +186 val_186 +187 val_187 +189 val_189 +190 val_190 +191 val_191 +192 val_192 +193 val_193 +194 val_194 +195 val_195 +196 val_196 +197 val_197 +199 val_199 +200 val_200 +201 val_201 +202 val_202 +203 val_203 +205 val_205 +207 val_207 +208 val_208 +209 val_209 +213 val_213 +214 val_214 +216 val_216 +217 val_217 +218 val_218 +219 val_219 +221 val_221 +222 val_222 +223 val_223 +224 val_224 +226 val_226 +228 val_228 +229 val_229 +230 val_230 +233 val_233 +235 val_235 +237 val_237 +238 val_238 +239 val_239 +241 val_241 +242 val_242 +244 val_244 +247 val_247 +248 val_248 +249 val_249 +252 val_252 +255 val_255 +256 val_256 +257 val_257 +258 val_258 +260 val_260 +262 val_262 +263 val_263 +265 val_265 +266 val_266 +272 val_272 +273 val_273 +274 val_274 +275 val_275 +277 val_277 +278 val_278 +280 val_280 +281 val_281 +282 val_282 +283 val_283 +284 val_284 +285 val_285 +286 val_286 +287 val_287 +288 val_288 +289 val_289 +291 val_291 +292 val_292 +296 val_296 +298 val_298 +305 val_305 +306 val_306 +307 val_307 +308 val_308 +309 val_309 +310 val_310 +311 val_311 +315 val_315 +316 val_316 +317 val_317 +318 val_318 +321 val_321 +322 val_322 +323 val_323 +325 val_325 +327 val_327 +331 val_331 +332 val_332 +333 val_333 +335 val_335 +336 val_336 +338 val_338 +339 val_339 +341 val_341 +342 val_342 +344 val_344 +345 val_345 +348 val_348 +351 val_351 +353 val_353 +356 val_356 +360 val_360 +362 val_362 +364 val_364 +365 val_365 +366 val_366 +367 val_367 +368 val_368 +369 val_369 +373 val_373 +374 val_374 +375 val_375 +377 val_377 +378 val_378 +379 val_379 +382 val_382 +384 val_384 +386 val_386 +389 val_389 +392 val_392 +393 val_393 +394 val_394 +395 val_395 +396 val_396 +397 val_397 +399 val_399 +400 val_400 +401 val_401 +402 val_402 +403 val_403 +404 val_404 +406 val_406 +407 val_407 +409 val_409 +411 val_411 +413 val_413 +414 val_414 +417 val_417 +418 val_418 +419 val_419 +421 val_421 +424 val_424 +427 val_427 +429 val_429 +430 val_430 +431 val_431 +432 val_432 +435 val_435 +436 val_436 +437 val_437 +438 val_438 +439 val_439 +443 val_443 +444 val_444 +446 val_446 +448 val_448 +449 val_449 +452 val_452 +453 val_453 +454 val_454 +455 val_455 +457 val_457 +458 val_458 +459 val_459 +460 val_460 +462 val_462 +463 val_463 +466 val_466 +467 val_467 +468 val_468 +469 val_469 +470 val_470 +472 val_472 +475 val_475 +477 val_477 +478 val_478 +479 val_479 +480 val_480 +481 val_481 +482 val_482 +483 val_483 +484 val_484 +485 val_485 +487 val_487 +489 val_489 +490 val_490 +491 val_491 +492 val_492 +493 val_493 +494 val_494 +495 val_495 +496 val_496 +497 val_497 +498 val_498 diff --git a/sql/hive/src/test/resources/golden/having-4-e9918bd385cb35db4ebcbd4e398547f4 b/sql/hive/src/test/resources/golden/having-4-e9918bd385cb35db4ebcbd4e398547f4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/having-5-4a0c4e521b8a6f6146151c13a2715ff b/sql/hive/src/test/resources/golden/having-5-4a0c4e521b8a6f6146151c13a2715ff new file mode 100644 index 0000000000000..2d7022e386303 --- /dev/null +++ b/sql/hive/src/test/resources/golden/having-5-4a0c4e521b8a6f6146151c13a2715ff @@ -0,0 +1,199 @@ +4 +5 +8 +9 +26 +27 +28 +30 +33 +34 +35 +37 +41 +42 +43 +44 +47 +51 +53 +54 +57 +58 +64 +65 +66 +67 +69 +70 +72 +74 +76 +77 +78 +80 +82 +83 +84 +85 +86 +87 +90 +92 +95 +96 +97 +98 +256 +257 +258 +260 +262 +263 +265 +266 +272 +273 +274 +275 +277 +278 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +291 +292 +296 +298 +302 +305 +306 +307 +308 +309 +310 +311 +315 +316 +317 +318 +321 +322 +323 +325 +327 +331 +332 +333 +335 +336 +338 +339 +341 +342 +344 +345 +348 +351 +353 +356 +360 +362 +364 +365 +366 +367 +368 +369 +373 +374 +375 +377 +378 +379 +382 +384 +386 +389 +392 +393 +394 +395 +396 +397 +399 +400 +401 +402 +403 +404 +406 +407 +409 +411 +413 +414 +417 +418 +419 +421 +424 +427 +429 +430 +431 +432 +435 +436 +437 +438 +439 +443 +444 +446 +448 +449 +452 +453 +454 +455 +457 +458 +459 +460 +462 +463 +466 +467 +468 +469 +470 +472 +475 +477 +478 +479 +480 +481 +482 +483 +484 +485 +487 +489 +490 +491 +492 +493 +494 +495 +496 +497 +498 diff --git a/sql/hive/src/test/resources/golden/having-6-9f50df5b5f31c7166b0396ab434dc095 b/sql/hive/src/test/resources/golden/having-6-9f50df5b5f31c7166b0396ab434dc095 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/having-7-5ad96cb287df02080da1e2594f08d83e b/sql/hive/src/test/resources/golden/having-7-5ad96cb287df02080da1e2594f08d83e new file mode 100644 index 0000000000000..bd545ccf7430c --- /dev/null +++ b/sql/hive/src/test/resources/golden/having-7-5ad96cb287df02080da1e2594f08d83e @@ -0,0 +1,125 @@ +302 +305 +306 +307 +308 +309 +310 +311 +315 +316 +317 +318 +321 +322 +323 +325 +327 +331 +332 +333 +335 +336 +338 +339 +341 +342 +344 +345 +348 +351 +353 +356 +360 +362 +364 +365 +366 +367 +368 +369 +373 +374 +375 +377 +378 +379 +382 +384 +386 +389 +392 +393 +394 +395 +396 +397 +399 +400 +401 +402 +403 +404 +406 +407 +409 +411 +413 +414 +417 +418 +419 +421 +424 +427 +429 +430 +431 +432 +435 +436 +437 +438 +439 +443 +444 +446 +448 +449 +452 +453 +454 +455 +457 +458 +459 +460 +462 +463 +466 +467 +468 +469 +470 +472 +475 +477 +478 +479 +480 +481 +482 +483 +484 +485 +487 +489 +490 +491 +492 +493 +494 +495 +496 +497 +498 diff --git a/sql/hive/src/test/resources/golden/having-8-4aa7197e20b5a64461ca670a79488103 b/sql/hive/src/test/resources/golden/having-8-4aa7197e20b5a64461ca670a79488103 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/having-9-a79743372d86d77b0ff53a71adcb1cff b/sql/hive/src/test/resources/golden/having-9-a79743372d86d77b0ff53a71adcb1cff new file mode 100644 index 0000000000000..d77586c12b6af --- /dev/null +++ b/sql/hive/src/test/resources/golden/having-9-a79743372d86d77b0ff53a71adcb1cff @@ -0,0 +1,199 @@ +4 val_4 +5 val_5 +8 val_8 +9 val_9 +26 val_26 +27 val_27 +28 val_28 +30 val_30 +33 val_33 +34 val_34 +35 val_35 +37 val_37 +41 val_41 +42 val_42 +43 val_43 +44 val_44 +47 val_47 +51 val_51 +53 val_53 +54 val_54 +57 val_57 +58 val_58 +64 val_64 +65 val_65 +66 val_66 +67 val_67 +69 val_69 +70 val_70 +72 val_72 +74 val_74 +76 val_76 +77 val_77 +78 val_78 +80 val_80 +82 val_82 +83 val_83 +84 val_84 +85 val_85 +86 val_86 +87 val_87 +90 val_90 +92 val_92 +95 val_95 +96 val_96 +97 val_97 +98 val_98 +256 val_256 +257 val_257 +258 val_258 +260 val_260 +262 val_262 +263 val_263 +265 val_265 +266 val_266 +272 val_272 +273 val_273 +274 val_274 +275 val_275 +277 val_277 +278 val_278 +280 val_280 +281 val_281 +282 val_282 +283 val_283 +284 val_284 +285 val_285 +286 val_286 +287 val_287 +288 val_288 +289 val_289 +291 val_291 +292 val_292 +296 val_296 +298 val_298 +302 val_302 +305 val_305 +306 val_306 +307 val_307 +308 val_308 +309 val_309 +310 val_310 +311 val_311 +315 val_315 +316 val_316 +317 val_317 +318 val_318 +321 val_321 +322 val_322 +323 val_323 +325 val_325 +327 val_327 +331 val_331 +332 val_332 +333 val_333 +335 val_335 +336 val_336 +338 val_338 +339 val_339 +341 val_341 +342 val_342 +344 val_344 +345 val_345 +348 val_348 +351 val_351 +353 val_353 +356 val_356 +360 val_360 +362 val_362 +364 val_364 +365 val_365 +366 val_366 +367 val_367 +368 val_368 +369 val_369 +373 val_373 +374 val_374 +375 val_375 +377 val_377 +378 val_378 +379 val_379 +382 val_382 +384 val_384 +386 val_386 +389 val_389 +392 val_392 +393 val_393 +394 val_394 +395 val_395 +396 val_396 +397 val_397 +399 val_399 +400 val_400 +401 val_401 +402 val_402 +403 val_403 +404 val_404 +406 val_406 +407 val_407 +409 val_409 +411 val_411 +413 val_413 +414 val_414 +417 val_417 +418 val_418 +419 val_419 +421 val_421 +424 val_424 +427 val_427 +429 val_429 +430 val_430 +431 val_431 +432 val_432 +435 val_435 +436 val_436 +437 val_437 +438 val_438 +439 val_439 +443 val_443 +444 val_444 +446 val_446 +448 val_448 +449 val_449 +452 val_452 +453 val_453 +454 val_454 +455 val_455 +457 val_457 +458 val_458 +459 val_459 +460 val_460 +462 val_462 +463 val_463 +466 val_466 +467 val_467 +468 val_468 +469 val_469 +470 val_470 +472 val_472 +475 val_475 +477 val_477 +478 val_478 +479 val_479 +480 val_480 +481 val_481 +482 val_482 +483 val_483 +484 val_484 +485 val_485 +487 val_487 +489 val_489 +490 val_490 +491 val_491 +492 val_492 +493 val_493 +494 val_494 +495 val_495 +496 val_496 +497 val_497 +498 val_498 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-0-869726b703f160eabdb7763700b53e60 b/sql/hive/src/test/resources/golden/join_nullsafe-0-869726b703f160eabdb7763700b53e60 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-0-869726b703f160eabdb7763700b53e60 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-1-5644ab44e5ba9f2941216b8d5dc33a99 b/sql/hive/src/test/resources/golden/join_nullsafe-1-5644ab44e5ba9f2941216b8d5dc33a99 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-10-b6de4e85dcc1d1949c7431d39fa1b919 b/sql/hive/src/test/resources/golden/join_nullsafe-10-b6de4e85dcc1d1949c7431d39fa1b919 new file mode 100644 index 0000000000000..31c409082cc2f --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-10-b6de4e85dcc1d1949c7431d39fa1b919 @@ -0,0 +1,2 @@ +NULL 10 10 NULL NULL 10 +100 100 100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-11-3aa243002a5363b84556736ef71613b1 b/sql/hive/src/test/resources/golden/join_nullsafe-11-3aa243002a5363b84556736ef71613b1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-12-3cc55b14e8256d2c51361b61986c291e b/sql/hive/src/test/resources/golden/join_nullsafe-12-3cc55b14e8256d2c51361b61986c291e new file mode 100644 index 0000000000000..9b77d13cbaab2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-12-3cc55b14e8256d2c51361b61986c291e @@ -0,0 +1,4 @@ +NULL NULL NULL NULL NULL NULL +NULL 10 10 NULL NULL 10 +10 NULL NULL 10 10 NULL +100 100 100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-13-69d94e229191e7b9b1a3e7eae46eb993 b/sql/hive/src/test/resources/golden/join_nullsafe-13-69d94e229191e7b9b1a3e7eae46eb993 new file mode 100644 index 0000000000000..47c0709d39851 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-13-69d94e229191e7b9b1a3e7eae46eb993 @@ -0,0 +1,12 @@ +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +10 NULL NULL 10 +48 NULL NULL NULL +100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-14-cf9ff6ee72a701a8e2f3e7fb0667903c b/sql/hive/src/test/resources/golden/join_nullsafe-14-cf9ff6ee72a701a8e2f3e7fb0667903c new file mode 100644 index 0000000000000..36ba48516b658 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-14-cf9ff6ee72a701a8e2f3e7fb0667903c @@ -0,0 +1,12 @@ +NULL NULL NULL NULL +NULL NULL NULL 35 +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +10 NULL NULL 10 +100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-15-507d0fa6d7ce39e2d9921555cea6f8da b/sql/hive/src/test/resources/golden/join_nullsafe-15-507d0fa6d7ce39e2d9921555cea6f8da new file mode 100644 index 0000000000000..fc1fd198cf8be --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-15-507d0fa6d7ce39e2d9921555cea6f8da @@ -0,0 +1,13 @@ +NULL NULL NULL NULL +NULL NULL NULL 35 +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +10 NULL NULL 10 +48 NULL NULL NULL +100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-16-1c714fc339304de4db630530e5d1ce97 b/sql/hive/src/test/resources/golden/join_nullsafe-16-1c714fc339304de4db630530e5d1ce97 new file mode 100644 index 0000000000000..1cc70524f9d6d --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-16-1c714fc339304de4db630530e5d1ce97 @@ -0,0 +1,11 @@ +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +10 NULL NULL 10 +100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-17-8a4b0dc781a28ad11a0db9805fe03aa8 b/sql/hive/src/test/resources/golden/join_nullsafe-17-8a4b0dc781a28ad11a0db9805fe03aa8 new file mode 100644 index 0000000000000..1cc70524f9d6d --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-17-8a4b0dc781a28ad11a0db9805fe03aa8 @@ -0,0 +1,11 @@ +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +10 NULL NULL 10 +100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-18-10b2051e65cac50ee1ea1c138ec192c8 b/sql/hive/src/test/resources/golden/join_nullsafe-18-10b2051e65cac50ee1ea1c138ec192c8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-19-23ab7ac8229a53d391195be7ca092429 b/sql/hive/src/test/resources/golden/join_nullsafe-19-23ab7ac8229a53d391195be7ca092429 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-2-793e288c9e0971f0bf3f37493f76dc7 b/sql/hive/src/test/resources/golden/join_nullsafe-2-793e288c9e0971f0bf3f37493f76dc7 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-20-d6fc260320c577eec9a5db0d4135d224 b/sql/hive/src/test/resources/golden/join_nullsafe-20-d6fc260320c577eec9a5db0d4135d224 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-21-a60dae725ffc543f805242611d99de4e b/sql/hive/src/test/resources/golden/join_nullsafe-21-a60dae725ffc543f805242611d99de4e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-22-24c80d0f9e3d72c48d947770fa184985 b/sql/hive/src/test/resources/golden/join_nullsafe-22-24c80d0f9e3d72c48d947770fa184985 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-23-3fe6ae20cab3417759dcc654a3a26746 b/sql/hive/src/test/resources/golden/join_nullsafe-23-3fe6ae20cab3417759dcc654a3a26746 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-24-2db30531137611e06fdba478ca7a8412 b/sql/hive/src/test/resources/golden/join_nullsafe-24-2db30531137611e06fdba478ca7a8412 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-24-2db30531137611e06fdba478ca7a8412 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-25-e58b2754e8d9c56a473557a549d0d2b9 b/sql/hive/src/test/resources/golden/join_nullsafe-25-e58b2754e8d9c56a473557a549d0d2b9 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-25-e58b2754e8d9c56a473557a549d0d2b9 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-26-64cabe5164130a94f387288f37b62d71 b/sql/hive/src/test/resources/golden/join_nullsafe-26-64cabe5164130a94f387288f37b62d71 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-26-64cabe5164130a94f387288f37b62d71 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-27-e8ed4a1b574a6ca70cbfb3f7b9980aa6 b/sql/hive/src/test/resources/golden/join_nullsafe-27-e8ed4a1b574a6ca70cbfb3f7b9980aa6 new file mode 100644 index 0000000000000..66482299904bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-27-e8ed4a1b574a6ca70cbfb3f7b9980aa6 @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL 10 +NULL NULL NULL 10 +NULL NULL NULL 35 +NULL NULL NULL 35 +NULL NULL NULL 110 +NULL NULL NULL 110 +NULL NULL NULL 135 +NULL NULL NULL 135 +NULL 10 NULL NULL +NULL 10 NULL NULL +NULL 10 NULL 10 +NULL 10 NULL 35 +NULL 10 NULL 110 +NULL 10 NULL 135 +NULL 35 NULL NULL +NULL 35 NULL NULL +NULL 35 NULL 10 +NULL 35 NULL 35 +NULL 35 NULL 110 +NULL 35 NULL 135 +NULL 110 NULL NULL +NULL 110 NULL NULL +NULL 110 NULL 10 +NULL 110 NULL 35 +NULL 110 NULL 110 +NULL 110 NULL 135 +NULL 135 NULL NULL +NULL 135 NULL NULL +NULL 135 NULL 10 +NULL 135 NULL 35 +NULL 135 NULL 110 +NULL 135 NULL 135 +10 NULL 10 NULL +48 NULL 48 NULL +100 100 100 100 +110 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-28-5a0c946cd7033857ca99e5fb800f8525 b/sql/hive/src/test/resources/golden/join_nullsafe-28-5a0c946cd7033857ca99e5fb800f8525 new file mode 100644 index 0000000000000..2efbef0484452 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-28-5a0c946cd7033857ca99e5fb800f8525 @@ -0,0 +1,14 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL 10 NULL 10 +NULL 35 NULL 35 +NULL 110 NULL 110 +NULL 135 NULL 135 +10 NULL 10 NULL +48 NULL 48 NULL +100 100 100 100 +110 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-29-514043c2ddaf6ea8f16a764adc92d1cf b/sql/hive/src/test/resources/golden/join_nullsafe-29-514043c2ddaf6ea8f16a764adc92d1cf new file mode 100644 index 0000000000000..66482299904bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-29-514043c2ddaf6ea8f16a764adc92d1cf @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL 10 +NULL NULL NULL 10 +NULL NULL NULL 35 +NULL NULL NULL 35 +NULL NULL NULL 110 +NULL NULL NULL 110 +NULL NULL NULL 135 +NULL NULL NULL 135 +NULL 10 NULL NULL +NULL 10 NULL NULL +NULL 10 NULL 10 +NULL 10 NULL 35 +NULL 10 NULL 110 +NULL 10 NULL 135 +NULL 35 NULL NULL +NULL 35 NULL NULL +NULL 35 NULL 10 +NULL 35 NULL 35 +NULL 35 NULL 110 +NULL 35 NULL 135 +NULL 110 NULL NULL +NULL 110 NULL NULL +NULL 110 NULL 10 +NULL 110 NULL 35 +NULL 110 NULL 110 +NULL 110 NULL 135 +NULL 135 NULL NULL +NULL 135 NULL NULL +NULL 135 NULL 10 +NULL 135 NULL 35 +NULL 135 NULL 110 +NULL 135 NULL 135 +10 NULL 10 NULL +48 NULL 48 NULL +100 100 100 100 +110 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-3-ae378fc0f875a21884e58fa35a6d52cd b/sql/hive/src/test/resources/golden/join_nullsafe-3-ae378fc0f875a21884e58fa35a6d52cd new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-30-fcbf92cb1b85ab01102fbbc6caba9a88 b/sql/hive/src/test/resources/golden/join_nullsafe-30-fcbf92cb1b85ab01102fbbc6caba9a88 new file mode 100644 index 0000000000000..66482299904bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-30-fcbf92cb1b85ab01102fbbc6caba9a88 @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL 10 +NULL NULL NULL 10 +NULL NULL NULL 35 +NULL NULL NULL 35 +NULL NULL NULL 110 +NULL NULL NULL 110 +NULL NULL NULL 135 +NULL NULL NULL 135 +NULL 10 NULL NULL +NULL 10 NULL NULL +NULL 10 NULL 10 +NULL 10 NULL 35 +NULL 10 NULL 110 +NULL 10 NULL 135 +NULL 35 NULL NULL +NULL 35 NULL NULL +NULL 35 NULL 10 +NULL 35 NULL 35 +NULL 35 NULL 110 +NULL 35 NULL 135 +NULL 110 NULL NULL +NULL 110 NULL NULL +NULL 110 NULL 10 +NULL 110 NULL 35 +NULL 110 NULL 110 +NULL 110 NULL 135 +NULL 135 NULL NULL +NULL 135 NULL NULL +NULL 135 NULL 10 +NULL 135 NULL 35 +NULL 135 NULL 110 +NULL 135 NULL 135 +10 NULL 10 NULL +48 NULL 48 NULL +100 100 100 100 +110 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-31-1cb03e1106f79d14f22bc89d386cedcf b/sql/hive/src/test/resources/golden/join_nullsafe-31-1cb03e1106f79d14f22bc89d386cedcf new file mode 100644 index 0000000000000..66482299904bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-31-1cb03e1106f79d14f22bc89d386cedcf @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL 10 +NULL NULL NULL 10 +NULL NULL NULL 35 +NULL NULL NULL 35 +NULL NULL NULL 110 +NULL NULL NULL 110 +NULL NULL NULL 135 +NULL NULL NULL 135 +NULL 10 NULL NULL +NULL 10 NULL NULL +NULL 10 NULL 10 +NULL 10 NULL 35 +NULL 10 NULL 110 +NULL 10 NULL 135 +NULL 35 NULL NULL +NULL 35 NULL NULL +NULL 35 NULL 10 +NULL 35 NULL 35 +NULL 35 NULL 110 +NULL 35 NULL 135 +NULL 110 NULL NULL +NULL 110 NULL NULL +NULL 110 NULL 10 +NULL 110 NULL 35 +NULL 110 NULL 110 +NULL 110 NULL 135 +NULL 135 NULL NULL +NULL 135 NULL NULL +NULL 135 NULL 10 +NULL 135 NULL 35 +NULL 135 NULL 110 +NULL 135 NULL 135 +10 NULL 10 NULL +48 NULL 48 NULL +100 100 100 100 +110 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-32-6a0bf6127d4b042e67ae8ee15125fb87 b/sql/hive/src/test/resources/golden/join_nullsafe-32-6a0bf6127d4b042e67ae8ee15125fb87 new file mode 100644 index 0000000000000..ea001a222f357 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-32-6a0bf6127d4b042e67ae8ee15125fb87 @@ -0,0 +1,40 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL NULL 48 NULL +NULL NULL 110 NULL +NULL NULL 110 NULL +NULL NULL 148 NULL +NULL NULL 148 NULL +NULL 10 NULL NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 10 110 NULL +NULL 10 148 NULL +NULL 35 NULL NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +NULL 35 110 NULL +NULL 35 148 NULL +NULL 110 NULL NULL +NULL 110 NULL NULL +NULL 110 10 NULL +NULL 110 48 NULL +NULL 110 110 NULL +NULL 110 148 NULL +NULL 135 NULL NULL +NULL 135 NULL NULL +NULL 135 10 NULL +NULL 135 48 NULL +NULL 135 110 NULL +NULL 135 148 NULL +10 NULL NULL 10 +100 100 100 100 +110 NULL NULL 110 +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-33-63157d43422fcedadba408537ccecd5c b/sql/hive/src/test/resources/golden/join_nullsafe-33-63157d43422fcedadba408537ccecd5c new file mode 100644 index 0000000000000..ea001a222f357 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-33-63157d43422fcedadba408537ccecd5c @@ -0,0 +1,40 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL NULL 48 NULL +NULL NULL 110 NULL +NULL NULL 110 NULL +NULL NULL 148 NULL +NULL NULL 148 NULL +NULL 10 NULL NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 10 110 NULL +NULL 10 148 NULL +NULL 35 NULL NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +NULL 35 110 NULL +NULL 35 148 NULL +NULL 110 NULL NULL +NULL 110 NULL NULL +NULL 110 10 NULL +NULL 110 48 NULL +NULL 110 110 NULL +NULL 110 148 NULL +NULL 135 NULL NULL +NULL 135 NULL NULL +NULL 135 10 NULL +NULL 135 48 NULL +NULL 135 110 NULL +NULL 135 148 NULL +10 NULL NULL 10 +100 100 100 100 +110 NULL NULL 110 +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-34-9265f806b71c03061f93f9fbc88aa223 b/sql/hive/src/test/resources/golden/join_nullsafe-34-9265f806b71c03061f93f9fbc88aa223 new file mode 100644 index 0000000000000..1093bd89f6e3f --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-34-9265f806b71c03061f93f9fbc88aa223 @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL NULL 48 NULL +NULL NULL 110 NULL +NULL NULL 110 NULL +NULL NULL 148 NULL +NULL NULL 148 NULL +NULL 10 NULL NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 10 110 NULL +NULL 10 148 NULL +NULL 35 NULL NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +NULL 35 110 NULL +NULL 35 148 NULL +NULL 110 NULL NULL +NULL 110 NULL NULL +NULL 110 10 NULL +NULL 110 48 NULL +NULL 110 110 NULL +NULL 110 148 NULL +NULL 135 NULL NULL +NULL 135 NULL NULL +NULL 135 10 NULL +NULL 135 48 NULL +NULL 135 110 NULL +NULL 135 148 NULL +10 NULL NULL 10 +48 NULL NULL NULL +100 100 100 100 +110 NULL NULL 110 +148 NULL NULL NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-35-95815bafb81cccb8129c20d399a446fc b/sql/hive/src/test/resources/golden/join_nullsafe-35-95815bafb81cccb8129c20d399a446fc new file mode 100644 index 0000000000000..9cf0036674d6e --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-35-95815bafb81cccb8129c20d399a446fc @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL 35 +NULL NULL NULL 135 +NULL NULL 10 NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL NULL 48 NULL +NULL NULL 110 NULL +NULL NULL 110 NULL +NULL NULL 148 NULL +NULL NULL 148 NULL +NULL 10 NULL NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 10 110 NULL +NULL 10 148 NULL +NULL 35 NULL NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +NULL 35 110 NULL +NULL 35 148 NULL +NULL 110 NULL NULL +NULL 110 NULL NULL +NULL 110 10 NULL +NULL 110 48 NULL +NULL 110 110 NULL +NULL 110 148 NULL +NULL 135 NULL NULL +NULL 135 NULL NULL +NULL 135 10 NULL +NULL 135 48 NULL +NULL 135 110 NULL +NULL 135 148 NULL +10 NULL NULL 10 +100 100 100 100 +110 NULL NULL 110 +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-36-c4762c60cc93236b7647ebd32a40ce57 b/sql/hive/src/test/resources/golden/join_nullsafe-36-c4762c60cc93236b7647ebd32a40ce57 new file mode 100644 index 0000000000000..77f6a8ddd7c28 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-36-c4762c60cc93236b7647ebd32a40ce57 @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL NULL 48 NULL +NULL NULL 110 NULL +NULL NULL 110 NULL +NULL NULL 148 NULL +NULL NULL 148 NULL +NULL 10 NULL 10 +NULL 35 NULL 35 +NULL 110 NULL 110 +NULL 135 NULL 135 +10 NULL NULL NULL +10 NULL NULL NULL +10 NULL 10 NULL +10 NULL 48 NULL +10 NULL 110 NULL +10 NULL 148 NULL +48 NULL NULL NULL +48 NULL NULL NULL +48 NULL 10 NULL +48 NULL 48 NULL +48 NULL 110 NULL +48 NULL 148 NULL +100 100 100 100 +110 NULL NULL NULL +110 NULL NULL NULL +110 NULL 10 NULL +110 NULL 48 NULL +110 NULL 110 NULL +110 NULL 148 NULL +148 NULL NULL NULL +148 NULL NULL NULL +148 NULL 10 NULL +148 NULL 48 NULL +148 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-37-a87893adfc73c9cc63ceab200bb56245 b/sql/hive/src/test/resources/golden/join_nullsafe-37-a87893adfc73c9cc63ceab200bb56245 new file mode 100644 index 0000000000000..77f6a8ddd7c28 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-37-a87893adfc73c9cc63ceab200bb56245 @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL NULL 48 NULL +NULL NULL 110 NULL +NULL NULL 110 NULL +NULL NULL 148 NULL +NULL NULL 148 NULL +NULL 10 NULL 10 +NULL 35 NULL 35 +NULL 110 NULL 110 +NULL 135 NULL 135 +10 NULL NULL NULL +10 NULL NULL NULL +10 NULL 10 NULL +10 NULL 48 NULL +10 NULL 110 NULL +10 NULL 148 NULL +48 NULL NULL NULL +48 NULL NULL NULL +48 NULL 10 NULL +48 NULL 48 NULL +48 NULL 110 NULL +48 NULL 148 NULL +100 100 100 100 +110 NULL NULL NULL +110 NULL NULL NULL +110 NULL 10 NULL +110 NULL 48 NULL +110 NULL 110 NULL +110 NULL 148 NULL +148 NULL NULL NULL +148 NULL NULL NULL +148 NULL 10 NULL +148 NULL 48 NULL +148 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-38-e3dfe0044b44c8a49414479521acf762 b/sql/hive/src/test/resources/golden/join_nullsafe-38-e3dfe0044b44c8a49414479521acf762 new file mode 100644 index 0000000000000..77f6a8ddd7c28 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-38-e3dfe0044b44c8a49414479521acf762 @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL NULL 48 NULL +NULL NULL 110 NULL +NULL NULL 110 NULL +NULL NULL 148 NULL +NULL NULL 148 NULL +NULL 10 NULL 10 +NULL 35 NULL 35 +NULL 110 NULL 110 +NULL 135 NULL 135 +10 NULL NULL NULL +10 NULL NULL NULL +10 NULL 10 NULL +10 NULL 48 NULL +10 NULL 110 NULL +10 NULL 148 NULL +48 NULL NULL NULL +48 NULL NULL NULL +48 NULL 10 NULL +48 NULL 48 NULL +48 NULL 110 NULL +48 NULL 148 NULL +100 100 100 100 +110 NULL NULL NULL +110 NULL NULL NULL +110 NULL 10 NULL +110 NULL 48 NULL +110 NULL 110 NULL +110 NULL 148 NULL +148 NULL NULL NULL +148 NULL NULL NULL +148 NULL 10 NULL +148 NULL 48 NULL +148 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-39-9a7e1f373b9c02e632d6c7c550b908ec b/sql/hive/src/test/resources/golden/join_nullsafe-39-9a7e1f373b9c02e632d6c7c550b908ec new file mode 100644 index 0000000000000..77f6a8ddd7c28 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-39-9a7e1f373b9c02e632d6c7c550b908ec @@ -0,0 +1,42 @@ +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL NULL 48 NULL +NULL NULL 110 NULL +NULL NULL 110 NULL +NULL NULL 148 NULL +NULL NULL 148 NULL +NULL 10 NULL 10 +NULL 35 NULL 35 +NULL 110 NULL 110 +NULL 135 NULL 135 +10 NULL NULL NULL +10 NULL NULL NULL +10 NULL 10 NULL +10 NULL 48 NULL +10 NULL 110 NULL +10 NULL 148 NULL +48 NULL NULL NULL +48 NULL NULL NULL +48 NULL 10 NULL +48 NULL 48 NULL +48 NULL 110 NULL +48 NULL 148 NULL +100 100 100 100 +110 NULL NULL NULL +110 NULL NULL NULL +110 NULL 10 NULL +110 NULL 48 NULL +110 NULL 110 NULL +110 NULL 148 NULL +148 NULL NULL NULL +148 NULL NULL NULL +148 NULL 10 NULL +148 NULL 48 NULL +148 NULL 110 NULL +148 NULL 148 NULL +200 200 200 200 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-4-644c616d87ae426eb2f8c71638045185 b/sql/hive/src/test/resources/golden/join_nullsafe-4-644c616d87ae426eb2f8c71638045185 new file mode 100644 index 0000000000000..1cc70524f9d6d --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-4-644c616d87ae426eb2f8c71638045185 @@ -0,0 +1,11 @@ +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL +10 NULL NULL 10 +100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-40-3c868718e4c120cb9a72ab7318c75be3 b/sql/hive/src/test/resources/golden/join_nullsafe-40-3c868718e4c120cb9a72ab7318c75be3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-41-1f7d8737c3e2d74d5ad865535d729811 b/sql/hive/src/test/resources/golden/join_nullsafe-41-1f7d8737c3e2d74d5ad865535d729811 new file mode 100644 index 0000000000000..421049d6e509e --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-41-1f7d8737c3e2d74d5ad865535d729811 @@ -0,0 +1,9 @@ +NULL NULL NULL NULL +NULL NULL 10 NULL +NULL NULL 48 NULL +NULL 10 NULL NULL +NULL 10 10 NULL +NULL 10 48 NULL +NULL 35 NULL NULL +NULL 35 10 NULL +NULL 35 48 NULL diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-5-1e393de94850e92b3b00536aacc9371f b/sql/hive/src/test/resources/golden/join_nullsafe-5-1e393de94850e92b3b00536aacc9371f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-6-d66451815212e7d17744184e74c6b0a0 b/sql/hive/src/test/resources/golden/join_nullsafe-6-d66451815212e7d17744184e74c6b0a0 new file mode 100644 index 0000000000000..aec3122cae5f9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-6-d66451815212e7d17744184e74c6b0a0 @@ -0,0 +1,2 @@ +10 NULL NULL 10 10 NULL +100 100 100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-7-a3ad3cc301d9884898d3e6ab6c792d4c b/sql/hive/src/test/resources/golden/join_nullsafe-7-a3ad3cc301d9884898d3e6ab6c792d4c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-8-cc7527bcf746ab7e2cd9f28db0ead0ac b/sql/hive/src/test/resources/golden/join_nullsafe-8-cc7527bcf746ab7e2cd9f28db0ead0ac new file mode 100644 index 0000000000000..30db79efa79b4 --- /dev/null +++ b/sql/hive/src/test/resources/golden/join_nullsafe-8-cc7527bcf746ab7e2cd9f28db0ead0ac @@ -0,0 +1,29 @@ +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL 10 +NULL NULL NULL NULL NULL 35 +NULL NULL 10 NULL NULL NULL +NULL NULL 10 NULL NULL 10 +NULL NULL 10 NULL NULL 35 +NULL NULL 48 NULL NULL NULL +NULL NULL 48 NULL NULL 10 +NULL NULL 48 NULL NULL 35 +NULL 10 NULL NULL NULL NULL +NULL 10 NULL NULL NULL 10 +NULL 10 NULL NULL NULL 35 +NULL 10 10 NULL NULL NULL +NULL 10 10 NULL NULL 10 +NULL 10 10 NULL NULL 35 +NULL 10 48 NULL NULL NULL +NULL 10 48 NULL NULL 10 +NULL 10 48 NULL NULL 35 +NULL 35 NULL NULL NULL NULL +NULL 35 NULL NULL NULL 10 +NULL 35 NULL NULL NULL 35 +NULL 35 10 NULL NULL NULL +NULL 35 10 NULL NULL 10 +NULL 35 10 NULL NULL 35 +NULL 35 48 NULL NULL NULL +NULL 35 48 NULL NULL 10 +NULL 35 48 NULL NULL 35 +10 NULL NULL 10 10 NULL +100 100 100 100 100 100 diff --git a/sql/hive/src/test/resources/golden/join_nullsafe-9-88f6f40959b0d2faabd9d4b3cd853809 b/sql/hive/src/test/resources/golden/join_nullsafe-9-88f6f40959b0d2faabd9d4b3cd853809 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_equal-0-36b6cdf7c5f68c91155569b1622f5876 b/sql/hive/src/test/resources/golden/udf_equal-0-36b6cdf7c5f68c91155569b1622f5876 new file mode 100644 index 0000000000000..9b9b6312a269a --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_equal-0-36b6cdf7c5f68c91155569b1622f5876 @@ -0,0 +1 @@ +a = b - Returns TRUE if a equals b and false otherwise diff --git a/sql/hive/src/test/resources/golden/udf_equal-1-2422b50b96502dde8b661acdfebd8892 b/sql/hive/src/test/resources/golden/udf_equal-1-2422b50b96502dde8b661acdfebd8892 new file mode 100644 index 0000000000000..30fdf50f62e4e --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_equal-1-2422b50b96502dde8b661acdfebd8892 @@ -0,0 +1,2 @@ +a = b - Returns TRUE if a equals b and false otherwise +Synonyms: == diff --git a/sql/hive/src/test/resources/golden/udf_equal-2-e0faab0f5e736c24bcc5503aeac55053 b/sql/hive/src/test/resources/golden/udf_equal-2-e0faab0f5e736c24bcc5503aeac55053 new file mode 100644 index 0000000000000..d6b4c860778b7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_equal-2-e0faab0f5e736c24bcc5503aeac55053 @@ -0,0 +1 @@ +a == b - Returns TRUE if a equals b and false otherwise diff --git a/sql/hive/src/test/resources/golden/udf_equal-3-39d8d6f197803de927f0af5409ec2f33 b/sql/hive/src/test/resources/golden/udf_equal-3-39d8d6f197803de927f0af5409ec2f33 new file mode 100644 index 0000000000000..71e55d6d638a6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_equal-3-39d8d6f197803de927f0af5409ec2f33 @@ -0,0 +1,2 @@ +a == b - Returns TRUE if a equals b and false otherwise +Synonyms: = diff --git a/sql/hive/src/test/resources/golden/udf_equal-4-94ac2476006425e1b3bcddf29ad07b16 b/sql/hive/src/test/resources/golden/udf_equal-4-94ac2476006425e1b3bcddf29ad07b16 new file mode 100644 index 0000000000000..015c417bc68f0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_equal-4-94ac2476006425e1b3bcddf29ad07b16 @@ -0,0 +1 @@ +false false true true NULL NULL NULL NULL NULL diff --git a/sql/hive/src/test/resources/golden/udf_equal-5-878650cf21e9360a07d204c8ffb0cde7 b/sql/hive/src/test/resources/golden/udf_equal-5-878650cf21e9360a07d204c8ffb0cde7 new file mode 100644 index 0000000000000..aa7b4b51edea7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_equal-5-878650cf21e9360a07d204c8ffb0cde7 @@ -0,0 +1 @@ +a <=> b - Returns same result with EQUAL(=) operator for non-null operands, but returns TRUE if both are NULL, FALSE if one of the them is NULL diff --git a/sql/hive/src/test/resources/golden/udf_equal-6-1635ef051fecdfc7891d9f5a9a3a545e b/sql/hive/src/test/resources/golden/udf_equal-6-1635ef051fecdfc7891d9f5a9a3a545e new file mode 100644 index 0000000000000..aa7b4b51edea7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_equal-6-1635ef051fecdfc7891d9f5a9a3a545e @@ -0,0 +1 @@ +a <=> b - Returns same result with EQUAL(=) operator for non-null operands, but returns TRUE if both are NULL, FALSE if one of the them is NULL diff --git a/sql/hive/src/test/resources/golden/udf_equal-7-78f1b96c199e307714fa1b804e5bae27 b/sql/hive/src/test/resources/golden/udf_equal-7-78f1b96c199e307714fa1b804e5bae27 new file mode 100644 index 0000000000000..05292fb23192d --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_equal-7-78f1b96c199e307714fa1b804e5bae27 @@ -0,0 +1 @@ +false false true true true false false false false diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index eb7df717284ce..8489f2a34e63c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -30,6 +30,18 @@ case class TestData(a: Int, b: String) */ class HiveQuerySuite extends HiveComparisonTest { + createQueryTest("boolean = number", + """ + |SELECT + | 1 = true, 1L = true, 1Y = true, true = 1, true = 1L, true = 1Y, + | 0 = true, 0L = true, 0Y = true, true = 0, true = 0L, true = 0Y, + | 1 = false, 1L = false, 1Y = false, false = 1, false = 1L, false = 1Y, + | 0 = false, 0L = false, 0Y = false, false = 0, false = 0L, false = 0Y, + | 2 = true, 2L = true, 2Y = true, true = 2, true = 2L, true = 2Y, + | 2 = false, 2L = false, 2Y = false, false = 2, false = 2L, false = 2Y + |FROM src LIMIT 1 + """.stripMargin) + test("CREATE TABLE AS runs once") { hql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect() assert(hql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1, @@ -404,10 +416,10 @@ class HiveQuerySuite extends HiveComparisonTest { hql(s"set $testKey=$testVal") assert(get(testKey, testVal + "_") == testVal) - hql("set mapred.reduce.tasks=20") - assert(get("mapred.reduce.tasks", "0") == "20") - hql("set mapred.reduce.tasks = 40") - assert(get("mapred.reduce.tasks", "0") == "40") + hql("set some.property=20") + assert(get("some.property", "0") == "20") + hql("set some.property = 40") + assert(get("some.property", "0") == "40") hql(s"set $testKey=$testVal") assert(get(testKey, "0") == testVal) @@ -421,63 +433,61 @@ class HiveQuerySuite extends HiveComparisonTest { val testKey = "spark.sql.key.usedfortestonly" val testVal = "test.val.0" val nonexistentKey = "nonexistent" - def collectResults(rdd: SchemaRDD): Set[(String, String)] = - rdd.collect().map { case Row(key: String, value: String) => key -> value }.toSet clear() // "set" itself returns all config variables currently specified in SQLConf. assert(hql("SET").collect().size == 0) - assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey=$testVal")) + assertResult(Array(s"$testKey=$testVal")) { + hql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal)) { - collectResults(hql("SET")) + assertResult(Array(s"$testKey=$testVal")) { + hql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } hql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(hql("SET")) + assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) { + hql(s"SET").collect().map(_.getString(0)) } // "set key" - assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey")) + assertResult(Array(s"$testKey=$testVal")) { + hql(s"SET $testKey").collect().map(_.getString(0)) } - assertResult(Set(nonexistentKey -> "")) { - collectResults(hql(s"SET $nonexistentKey")) + assertResult(Array(s"$nonexistentKey=")) { + hql(s"SET $nonexistentKey").collect().map(_.getString(0)) } // Assert that sql() should have the same effects as hql() by repeating the above using sql(). clear() assert(sql("SET").collect().size == 0) - assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey=$testVal")) + assertResult(Array(s"$testKey=$testVal")) { + sql(s"SET $testKey=$testVal").collect().map(_.getString(0)) } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal)) { - collectResults(sql("SET")) + assertResult(Array(s"$testKey=$testVal")) { + sql("SET").collect().map(_.getString(0)) } sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(sql("SET")) + assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) { + sql("SET").collect().map(_.getString(0)) } - assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey")) + assertResult(Array(s"$testKey=$testVal")) { + sql(s"SET $testKey").collect().map(_.getString(0)) } - assertResult(Set(nonexistentKey -> "")) { - collectResults(sql(s"SET $nonexistentKey")) + assertResult(Array(s"$nonexistentKey=")) { + sql(s"SET $nonexistentKey").collect().map(_.getString(0)) } clear() diff --git a/streaming/pom.xml b/streaming/pom.xml index f60697ce745b7..b99f306b8f2cc 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -28,7 +28,7 @@ org.apache.spark spark-streaming_2.10 - streaming + streaming jar Spark Project Streaming diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index 40da31318942e..1a47089e513c4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -133,17 +133,17 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( val numOldValues = oldRDDs.size val numNewValues = newRDDs.size - val mergeValues = (seqOfValues: Seq[Seq[V]]) => { - if (seqOfValues.size != 1 + numOldValues + numNewValues) { + val mergeValues = (arrayOfValues: Array[Iterable[V]]) => { + if (arrayOfValues.size != 1 + numOldValues + numNewValues) { throw new Exception("Unexpected number of sequences of reduced values") } // Getting reduced values "old time steps" that will be removed from current window - val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) + val oldValues = (1 to numOldValues).map(i => arrayOfValues(i)).filter(!_.isEmpty).map(_.head) // Getting reduced values "new time steps" val newValues = - (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) + (1 to numNewValues).map(i => arrayOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) - if (seqOfValues(0).isEmpty) { + if (arrayOfValues(0).isEmpty) { // If previous window's reduce value does not exist, then at least new values should exist if (newValues.isEmpty) { throw new Exception("Neither previous window has value for key, nor new values found. " + @@ -153,7 +153,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( newValues.reduce(reduceF) // return } else { // Get the previous window's reduced value - var tempValue = seqOfValues(0).head + var tempValue = arrayOfValues(0).head // If old values exists, then inverse reduce then from previous value if (!oldValues.isEmpty) { tempValue = invReduceF(tempValue, oldValues.reduce(reduceF)) @@ -166,7 +166,8 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( } } - val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) + val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K, Array[Iterable[V]])]] + .mapValues(mergeValues) if (filterFunc.isDefined) { Some(mergedValuesRDD.filter(filterFunc.get)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala index 743be58950c09..1868a1ebc7b4a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala @@ -68,13 +68,13 @@ object ActorSupervisorStrategy { * should be same. */ @DeveloperApi -trait ActorHelper { +trait ActorHelper extends Logging{ self: Actor => // to ensure that this can be added to Actor classes only /** Store an iterator of received data as a data block into Spark's memory. */ def store[T](iter: Iterator[T]) { - println("Storing iterator") + logDebug("Storing iterator") context.parent ! IteratorData(iter) } @@ -84,6 +84,7 @@ trait ActorHelper { * that Spark is configured to use. */ def store(bytes: ByteBuffer) { + logDebug("Storing Bytes") context.parent ! ByteBufferData(bytes) } @@ -93,7 +94,7 @@ trait ActorHelper { * being pushed into Spark's memory. */ def store[T](item: T) { - println("Storing item") + logDebug("Storing item") context.parent ! SingleItemData(item) } } @@ -157,15 +158,16 @@ private[streaming] class ActorReceiver[T: ClassTag]( def receive = { case IteratorData(iterator) => - println("received iterator") + logDebug("received iterator") store(iterator.asInstanceOf[Iterator[T]]) case SingleItemData(msg) => - println("received single") + logDebug("received single") store(msg.asInstanceOf[T]) n.incrementAndGet case ByteBufferData(bytes) => + logDebug("received bytes") store(bytes) case props: Props => diff --git a/tools/pom.xml b/tools/pom.xml index c0ee8faa7a615..97abb6b2b63e0 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -27,7 +27,7 @@ org.apache.spark spark-tools_2.10 - tools + tools jar Spark Project Tools diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 03a73f92b275e..566983675bff5 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -99,9 +99,25 @@ object GenerateMIMAIgnore { (ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet) } + /** Scala reflection does not let us see inner function even if they are upgraded + * to public for some reason. So had to resort to java reflection to get all inner + * functions with $$ in there name. + */ + def getInnerFunctions(classSymbol: unv.ClassSymbol): Seq[String] = { + try { + Class.forName(classSymbol.fullName, false, classLoader).getMethods.map(_.getName) + .filter(_.contains("$$")).map(classSymbol.fullName + "." + _) + } catch { + case t: Throwable => + println("[WARN] Unable to detect inner functions for class:" + classSymbol.fullName) + Seq.empty[String] + } + } + private def getAnnotatedOrPackagePrivateMembers(classSymbol: unv.ClassSymbol) = { classSymbol.typeSignature.members - .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName) + .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName) ++ + getInnerFunctions(classSymbol) } def main(args: Array[String]) { @@ -121,7 +137,8 @@ object GenerateMIMAIgnore { name.endsWith("$class") || name.contains("$sp") || name.contains("hive") || - name.contains("Hive") + name.contains("Hive") || + name.contains("repl") } /** diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml index 5b13a1f002d6e..51744ece0412d 100644 --- a/yarn/alpha/pom.xml +++ b/yarn/alpha/pom.xml @@ -24,7 +24,7 @@ ../pom.xml - yarn-alpha + yarn-alpha org.apache.spark diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 3ec36487dcd26..62b5c3bc5f0f3 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -60,6 +60,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, private var yarnAllocator: YarnAllocationHandler = _ private var isFinished: Boolean = false private var uiAddress: String = _ + private var uiHistoryAddress: String = _ private val maxAppAttempts: Int = conf.getInt(YarnConfiguration.RM_AM_MAX_RETRIES, YarnConfiguration.DEFAULT_RM_AM_MAX_RETRIES) private var isLastAMRetry: Boolean = true @@ -237,6 +238,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, if (null != sparkContext) { uiAddress = sparkContext.ui.appUIHostPort + uiHistoryAddress = YarnSparkHadoopUtil.getUIHistoryAddress(sparkContext, sparkConf) this.yarnAllocator = YarnAllocationHandler.newAllocator( yarnConf, resourceManager, @@ -360,7 +362,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, finishReq.setAppAttemptId(appAttemptId) finishReq.setFinishApplicationStatus(status) finishReq.setDiagnostics(diagnostics) - finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", "")) + finishReq.setTrackingUrl(uiHistoryAddress) resourceManager.finishApplicationMaster(finishReq) } } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index a86ad256dfa39..184e2ad6c82cd 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -28,7 +28,6 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import akka.actor._ import akka.remote._ -import akka.actor.Terminated import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend @@ -57,10 +56,17 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) private var yarnAllocator: YarnAllocationHandler = _ - private var driverClosed:Boolean = false + + private var driverClosed: Boolean = false + private var isFinished: Boolean = false + private var registered: Boolean = false + + // Default to numExecutors * 2, with minimum of 3 + private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", + sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) val securityManager = new SecurityManager(sparkConf) - val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, + val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, conf = sparkConf, securityManager = securityManager)._1 var actor: ActorRef = _ @@ -97,23 +103,26 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp appAttemptId = getApplicationAttemptId() resourceManager = registerWithResourceManager() - val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() - - // Compute number of threads for akka - val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() - - if (minimumMemory > 0) { - val mem = args.executorMemory + sparkConf.getInt("spark.yarn.executor.memoryOverhead", - YarnAllocationHandler.MEMORY_OVERHEAD) - val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) - - if (numCore > 0) { - // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 - // TODO: Uncomment when hadoop is on a version which has this fixed. - // args.workerCores = numCore + synchronized { + if (!isFinished) { + val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() + // Compute number of threads for akka + val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() + + if (minimumMemory > 0) { + val mem = args.executorMemory + sparkConf.getInt("spark.yarn.executor.memoryOverhead", + YarnAllocationHandler.MEMORY_OVERHEAD) + val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) + + if (numCore > 0) { + // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 + // TODO: Uncomment when hadoop is on a version which has this fixed. + // args.workerCores = numCore + } + } + registered = true } } - waitForSparkMaster() addAmIpFilter() // Allocate all containers @@ -243,11 +252,17 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { yarnAllocator.allocateContainers( math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0)) + checkNumExecutorsFailed() Thread.sleep(100) } logInfo("All executors have launched.") - + } + private def checkNumExecutorsFailed() { + if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + finishApplicationMaster(FinalApplicationStatus.FAILED, + "max number of executor failures reached") + } } // TODO: We might want to extend this to allocate more containers in case they die ! @@ -257,6 +272,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp val t = new Thread { override def run() { while (!driverClosed) { + checkNumExecutorsFailed() val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning if (missingExecutorCount > 0) { logInfo("Allocating " + missingExecutorCount + @@ -282,15 +298,23 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp yarnAllocator.allocateContainers(0) } - def finishApplicationMaster(status: FinalApplicationStatus) { - - logInfo("finish ApplicationMaster with " + status) - val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) - .asInstanceOf[FinishApplicationMasterRequest] - finishReq.setAppAttemptId(appAttemptId) - finishReq.setFinishApplicationStatus(status) - finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", "")) - resourceManager.finishApplicationMaster(finishReq) + def finishApplicationMaster(status: FinalApplicationStatus, appMessage: String = "") { + synchronized { + if (isFinished) { + return + } + logInfo("Unregistering ApplicationMaster with " + status) + if (registered) { + val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) + .asInstanceOf[FinishApplicationMasterRequest] + finishReq.setAppAttemptId(appAttemptId) + finishReq.setFinishApplicationStatus(status) + finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", "")) + finishReq.setDiagnostics(appMessage) + resourceManager.finishApplicationMaster(finishReq) + } + isFinished = true + } } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 556f49342977a..a1298e8f30b5c 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -232,7 +232,8 @@ trait ClientBase extends Logging { if (!ClientBase.LOCAL_SCHEME.equals(localURI.getScheme())) { val setPermissions = if (destName.equals(ClientBase.APP_JAR)) true else false val destPath = copyRemoteFile(dst, qualifyForLocal(localURI), replication, setPermissions) - distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + val destFs = FileSystem.get(destPath.toUri(), conf) + distCacheMgr.addResource(destFs, conf, destPath, localResources, LocalResourceType.FILE, destName, statCache) } else if (confKey != null) { sparkConf.set(confKey, localPath) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 718cb19f57261..e98308cdbd74e 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -30,6 +30,9 @@ import org.apache.hadoop.util.StringInterner import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.SparkHadoopUtil /** @@ -132,4 +135,17 @@ object YarnSparkHadoopUtil { } } + def getUIHistoryAddress(sc: SparkContext, conf: SparkConf) : String = { + val eventLogDir = sc.eventLogger match { + case Some(logger) => logger.getApplicationLogDir() + case None => "" + } + val historyServerAddress = conf.get("spark.yarn.historyServer.address", "") + if (historyServerAddress != "" && eventLogDir != "") { + historyServerAddress + HistoryServer.UI_PATH_PREFIX + s"/$eventLogDir" + } else { + "" + } + } + } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d8266f7b0c9a7..f8fb96b312f23 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} import org.apache.spark.{SparkException, Logging, SparkContext} -import org.apache.spark.deploy.yarn.{Client, ClientArguments, ExecutorLauncher} +import org.apache.spark.deploy.yarn.{Client, ClientArguments, ExecutorLauncher, YarnSparkHadoopUtil} import org.apache.spark.scheduler.TaskSchedulerImpl import scala.collection.mutable.ArrayBuffer @@ -37,6 +37,8 @@ private[spark] class YarnClientSchedulerBackend( var client: Client = null var appId: ApplicationId = null + var checkerThread: Thread = null + var stopping: Boolean = false private[spark] def addArg(optionName: String, envVar: String, sysProp: String, arrayBuf: ArrayBuffer[String]) { @@ -54,6 +56,7 @@ private[spark] class YarnClientSchedulerBackend( val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort conf.set("spark.driver.appUIAddress", sc.ui.appUIHostPort) + conf.set("spark.driver.appUIHistoryAddress", YarnSparkHadoopUtil.getUIHistoryAddress(sc, conf)) val argsArrayBuf = new ArrayBuffer[String]() argsArrayBuf += ( @@ -85,6 +88,7 @@ private[spark] class YarnClientSchedulerBackend( client = new Client(args, conf) appId = client.runApp() waitForApp() + checkerThread = yarnApplicationStateCheckerThread() } def waitForApp() { @@ -115,7 +119,32 @@ private[spark] class YarnClientSchedulerBackend( } } + private def yarnApplicationStateCheckerThread(): Thread = { + val t = new Thread { + override def run() { + while (!stopping) { + val report = client.getApplicationReport(appId) + val state = report.getYarnApplicationState() + if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.KILLED + || state == YarnApplicationState.FAILED) { + logError(s"Yarn application already ended: $state") + sc.stop() + stopping = true + } + Thread.sleep(1000L) + } + checkerThread = null + Thread.currentThread().interrupt() + } + } + t.setName("Yarn Application State Checker") + t.setDaemon(true) + t.start() + t + } + override def stop() { + stopping = true super.stop() client.stop logInfo("Stopped") diff --git a/yarn/pom.xml b/yarn/pom.xml index efb473aa1b261..3faaf053634d6 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -29,7 +29,7 @@ pom Spark Project YARN Parent POM - yarn + yarn diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index ceaf9f9d71001..b6c8456d06684 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -24,7 +24,7 @@ ../pom.xml - yarn-stable + yarn-stable org.apache.spark diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index eaf594c8b49b9..035356d390c80 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -59,6 +59,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, private var yarnAllocator: YarnAllocationHandler = _ private var isFinished: Boolean = false private var uiAddress: String = _ + private var uiHistoryAddress: String = _ private val maxAppAttempts: Int = conf.getInt( YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS) private var isLastAMRetry: Boolean = true @@ -216,6 +217,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, if (sparkContext != null) { uiAddress = sparkContext.ui.appUIHostPort + uiHistoryAddress = YarnSparkHadoopUtil.getUIHistoryAddress(sparkContext, sparkConf) this.yarnAllocator = YarnAllocationHandler.newAllocator( yarnConf, amClient, @@ -312,8 +314,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, logInfo("Unregistering ApplicationMaster with " + status) if (registered) { - val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "") - amClient.unregisterApplicationMaster(status, diagnostics, trackingUrl) + amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) } } } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index 5ac95f3798723..fc7b8320d734d 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -19,15 +19,12 @@ package org.apache.spark.deploy.yarn import java.net.Socket import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.net.NetUtils -import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import akka.actor._ import akka.remote._ -import akka.actor.Terminated import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend @@ -57,10 +54,16 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) private var yarnAllocator: YarnAllocationHandler = _ - private var driverClosed:Boolean = false + private var driverClosed: Boolean = false + private var isFinished: Boolean = false + private var registered: Boolean = false private var amClient: AMRMClient[ContainerRequest] = _ + // Default to numExecutors * 2, with minimum of 3 + private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", + sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) + val securityManager = new SecurityManager(sparkConf) val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, conf = sparkConf, securityManager = securityManager)._1 @@ -101,7 +104,12 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp amClient.start() appAttemptId = ApplicationMaster.getApplicationAttemptId() - registerApplicationMaster() + synchronized { + if (!isFinished) { + registerApplicationMaster() + registered = true + } + } waitForSparkMaster() addAmIpFilter() @@ -210,6 +218,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp yarnAllocator.addResourceRequests(args.numExecutors) yarnAllocator.allocateResources() while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { + checkNumExecutorsFailed() allocateMissingExecutor() yarnAllocator.allocateResources() Thread.sleep(100) @@ -228,12 +237,20 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp } } + private def checkNumExecutorsFailed() { + if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + finishApplicationMaster(FinalApplicationStatus.FAILED, + "max number of executor failures reached") + } + } + private def launchReporterThread(_sleepTime: Long): Thread = { val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime val t = new Thread { override def run() { while (!driverClosed) { + checkNumExecutorsFailed() allocateMissingExecutor() logDebug("Sending progress") yarnAllocator.allocateResources() @@ -248,10 +265,18 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp t } - def finishApplicationMaster(status: FinalApplicationStatus) { - logInfo("Unregistering ApplicationMaster with " + status) - val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "") - amClient.unregisterApplicationMaster(status, "" /* appMessage */ , trackingUrl) + def finishApplicationMaster(status: FinalApplicationStatus, appMessage: String = "") { + synchronized { + if (isFinished) { + return + } + logInfo("Unregistering ApplicationMaster with " + status) + if (registered) { + val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "") + amClient.unregisterApplicationMaster(status, appMessage, trackingUrl) + } + isFinished = true + } } }