diff --git a/assembly/pom.xml b/assembly/pom.xml index b2a9d0780ee2b..1bb5a671f5390 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -43,12 +43,6 @@ - - - com.google.guava - guava - compile - org.apache.spark spark-core_${scala.binary.version} @@ -133,20 +127,6 @@ shade - - - com.google - org.spark-project.guava - - com.google.common.** - - - com/google/common/base/Absent* - com/google/common/base/Optional* - com/google/common/base/Present* - - - diff --git a/bin/spark-class b/bin/spark-class index 1b945461fabc8..2f0441bb3c1c2 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -29,6 +29,7 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" +export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}" . "$FWDIR"/bin/load-spark-env.sh @@ -120,8 +121,8 @@ fi JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM" # Load extra JAVA_OPTS from conf/java-opts, if it exists -if [ -e "$FWDIR/conf/java-opts" ] ; then - JAVA_OPTS="$JAVA_OPTS `cat "$FWDIR"/conf/java-opts`" +if [ -e "$SPARK_CONF_DIR/java-opts" ] ; then + JAVA_OPTS="$JAVA_OPTS `cat "$SPARK_CONF_DIR"/java-opts`" fi # Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala! diff --git a/build/mvn b/build/mvn index 43471f83e904c..a87c5a26230c8 100755 --- a/build/mvn +++ b/build/mvn @@ -48,11 +48,11 @@ install_app() { # check if we already have the tarball # check if we have curl installed # download application - [ ! -f "${local_tarball}" ] && [ -n "`which curl 2>/dev/null`" ] && \ + [ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \ echo "exec: curl ${curl_opts} ${remote_tarball}" && \ curl ${curl_opts} "${remote_tarball}" > "${local_tarball}" # if the file still doesn't exist, lets try `wget` and cross our fingers - [ ! -f "${local_tarball}" ] && [ -n "`which wget 2>/dev/null`" ] && \ + [ ! 
-f "${local_tarball}" ] && [ $(command -v wget) ] && \ echo "exec: wget ${wget_opts} ${remote_tarball}" && \ wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}" # if both were unsuccessful, exit @@ -68,10 +68,10 @@ install_app() { # Install maven under the build/ folder install_mvn() { install_app \ - "http://apache.claz.org/maven/maven-3/3.2.3/binaries" \ - "apache-maven-3.2.3-bin.tar.gz" \ - "apache-maven-3.2.3/bin/mvn" - MVN_BIN="${_DIR}/apache-maven-3.2.3/bin/mvn" + "http://archive.apache.org/dist/maven/maven-3/3.2.5/binaries" \ + "apache-maven-3.2.5-bin.tar.gz" \ + "apache-maven-3.2.5/bin/mvn" + MVN_BIN="${_DIR}/apache-maven-3.2.5/bin/mvn" } # Install zinc under the build/ folder diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash index f5df439effb01..5e0c640fa5919 100755 --- a/build/sbt-launch-lib.bash +++ b/build/sbt-launch-lib.bash @@ -50,9 +50,9 @@ acquire_sbt_jar () { # Download printf "Attempting to fetch sbt\n" JAR_DL="${JAR}.part" - if hash curl 2>/dev/null; then + if [ $(command -v curl) ]; then (curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" - elif hash wget 2>/dev/null; then + elif [ $(command -v wget) ]; then (wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" else printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" diff --git a/core/pom.xml b/core/pom.xml index d9a49c9e08afc..31e919a1c831a 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -34,6 +34,10 @@ Spark Project Core http://spark.apache.org/ + + com.google.guava + guava + com.twitter chill_${scala.binary.version} @@ -106,16 +110,6 @@ org.eclipse.jetty jetty-server - - - com.google.guava - guava - compile - org.apache.commons commons-lang3 @@ -350,42 +344,6 @@ true - - org.apache.maven.plugins - maven-shade-plugin - - - package - - shade - - - false - - - com.google.guava:guava - - - - - - 
com.google.guava:guava - - com/google/common/base/Absent* - com/google/common/base/Optional* - com/google/common/base/Present* - - - - - - - - org.apache.maven.plugins maven-dependency-plugin diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index c99a61f63ea2b..89eec7d4b7f61 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -10,4 +10,3 @@ log4j.logger.org.eclipse.jetty=WARN log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO -log4j.logger.org.apache.hadoop.yarn.util.RackResolver=WARN diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index a1f7133f897ee..f23ba9dba167f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -190,6 +190,7 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. 
*/ -.scheduler_delay, .deserialization_time, .serialization_time, .getting_result_time { +.scheduler_delay, .deserialization_time, .fetch_wait_time, .serialization_time, +.getting_result_time { display: none; } diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index d4f2624061e35..419d093d55643 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -118,15 +118,17 @@ trait Logging { // org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently // org.apache.logging.slf4j.Log4jLoggerFactory val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) - val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements - if (!log4j12Initialized && usingLog4j12) { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - System.err.println(s"Spark was unable to load $defaultLogProps") + if (usingLog4j12) { + val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + if (!log4j12Initialized) { + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + System.err.println(s"Spark was unable to load $defaultLogProps") + } } } Logging.initialized = true diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index f9d4aa4240e9d..cd91c8f87547b 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ 
b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -17,9 +17,11 @@ package org.apache.spark +import java.util.concurrent.ConcurrentHashMap + import scala.collection.JavaConverters._ -import scala.collection.concurrent.TrieMap -import scala.collection.mutable.{HashMap, LinkedHashSet} +import scala.collection.mutable.LinkedHashSet + import org.apache.spark.serializer.KryoSerializer /** @@ -47,12 +49,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Create a SparkConf that loads defaults from system properties and the classpath */ def this() = this(true) - private[spark] val settings = new TrieMap[String, String]() + private val settings = new ConcurrentHashMap[String, String]() if (loadDefaults) { // Load any spark.* system properties for ((k, v) <- System.getProperties.asScala if k.startsWith("spark.")) { - settings(k) = v + set(k, v) } } @@ -64,7 +66,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { if (value == null) { throw new NullPointerException("null value for " + key) } - settings(key) = value + settings.put(key, value) this } @@ -130,15 +132,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Set multiple parameters together */ def setAll(settings: Traversable[(String, String)]) = { - this.settings ++= settings + this.settings.putAll(settings.toMap.asJava) this } /** Set a parameter if it isn't already configured */ def setIfMissing(key: String, value: String): SparkConf = { - if (!settings.contains(key)) { - settings(key) = value - } + settings.putIfAbsent(key, value) this } @@ -164,21 +164,23 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Get a parameter; throws a NoSuchElementException if it's not set */ def get(key: String): String = { - settings.getOrElse(key, throw new NoSuchElementException(key)) + getOption(key).getOrElse(throw new NoSuchElementException(key)) } /** Get a parameter, falling back to a default if not set */ 
def get(key: String, defaultValue: String): String = { - settings.getOrElse(key, defaultValue) + getOption(key).getOrElse(defaultValue) } /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { - settings.get(key) + Option(settings.get(key)) } /** Get all parameters as a list of pairs */ - def getAll: Array[(String, String)] = settings.toArray + def getAll: Array[(String, String)] = { + settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray + } /** Get a parameter as an integer, falling back to a default if not set */ def getInt(key: String, defaultValue: Int): Int = { @@ -225,11 +227,11 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getAppId: String = get("spark.app.id") /** Does the configuration contain a given parameter? */ - def contains(key: String): Boolean = settings.contains(key) + def contains(key: String): Boolean = settings.containsKey(key) /** Copy this object */ override def clone: SparkConf = { - new SparkConf(false).setAll(settings) + new SparkConf(false).setAll(getAll) } /** @@ -241,7 +243,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Checks for illegal or deprecated config settings. Throws an exception for the former. Not * idempotent - may mutate this conf object to convert deprecated settings to supported ones. */ private[spark] def validateSettings() { - if (settings.contains("spark.local.dir")) { + if (contains("spark.local.dir")) { val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " + "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)." 
logWarning(msg) @@ -266,7 +268,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } // Validate spark.executor.extraJavaOptions - settings.get(executorOptsKey).map { javaOpts => + getOption(executorOptsKey).map { javaOpts => if (javaOpts.contains("-Dspark")) { val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " + "Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit." @@ -346,7 +348,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { * configuration out for debugging. */ def toDebugString: String = { - settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") + getAll.sorted.map{case (k, v) => k + "=" + v}.mkString("\n") } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6a354ed4d1486..3c61c10820ba9 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -85,6 +85,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val startTime = System.currentTimeMillis() + @volatile private var stopped: Boolean = false + + private def assertNotStopped(): Unit = { + if (stopped) { + throw new IllegalStateException("Cannot call methods on a stopped SparkContext") + } + } + /** * Create a SparkContext that loads settings from system properties (for instance, when * launching with ./bin/spark-submit). @@ -525,6 +533,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * modified collection. Pass a copy of the argument to avoid this. 
*/ def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { + assertNotStopped() new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } @@ -540,6 +549,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * location preferences (hostnames of Spark nodes) for each object. * Create a new partition for each collection item. */ def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { + assertNotStopped() val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) } @@ -549,6 +559,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Hadoop-supported file system URI, and return it as an RDD of Strings. */ def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = { + assertNotStopped() hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], minPartitions).map(pair => pair._2.toString).setName(path) } @@ -582,6 +593,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, String)] = { + assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) NewFileInputFormat.addInputPath(job, new Path(path)) val updateConf = job.getConfiguration @@ -627,6 +639,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @Experimental def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, PortableDataStream)] = { + assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) NewFileInputFormat.addInputPath(job, new Path(path)) val updateConf = job.getConfiguration @@ -651,6 +664,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli @Experimental def binaryRecords(path: String, recordLength: Int, conf: 
Configuration = hadoopConfiguration) : RDD[Array[Byte]] = { + assertNotStopped() conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path, classOf[FixedLengthBinaryInputFormat], @@ -684,6 +698,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int = defaultMinPartitions ): RDD[(K, V)] = { + assertNotStopped() // Add necessary security credentials to the JobConf before broadcasting it. SparkHadoopUtil.get.addCredentials(conf) new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions) @@ -703,6 +718,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int = defaultMinPartitions ): RDD[(K, V)] = { + assertNotStopped() // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration)) val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path) @@ -782,6 +798,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli kClass: Class[K], vClass: Class[V], conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { + assertNotStopped() val job = new NewHadoopJob(conf) NewFileInputFormat.addInputPath(job, new Path(path)) val updatedConf = job.getConfiguration @@ -802,6 +819,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli fClass: Class[F], kClass: Class[K], vClass: Class[V]): RDD[(K, V)] = { + assertNotStopped() new NewHadoopRDD(this, fClass, kClass, vClass, conf) } @@ -817,6 +835,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int ): RDD[(K, V)] = { + assertNotStopped() val inputFormatClass = classOf[SequenceFileInputFormat[K, V]] 
hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions) } @@ -828,9 +847,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. * */ - def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V] - ): RDD[(K, V)] = + def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = { + assertNotStopped() sequenceFile(path, keyClass, valueClass, defaultMinPartitions) + } /** * Version of sequenceFile() for types implicitly convertible to Writables through a @@ -858,6 +878,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli (implicit km: ClassTag[K], vm: ClassTag[V], kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) : RDD[(K, V)] = { + assertNotStopped() val kc = kcf() val vc = vcf() val format = classOf[SequenceFileInputFormat[Writable, Writable]] @@ -879,6 +900,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli path: String, minPartitions: Int = defaultMinPartitions ): RDD[T] = { + assertNotStopped() sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions) .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader)) } @@ -954,6 +976,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * The variable will be sent to each cluster only once. 
*/ def broadcast[T: ClassTag](value: T): Broadcast[T] = { + assertNotStopped() + if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) { + // This is a warning instead of an exception in order to avoid breaking user programs that + // might have created RDD broadcast variables but not used them: + logWarning("Can not directly broadcast RDDs; instead, call collect() and " + + "broadcast the result (see SPARK-5063)") + } val bc = env.broadcastManager.newBroadcast[T](value, isLocal) val callSite = getCallSite logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm) @@ -1046,6 +1075,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * memory available for caching. */ def getExecutorMemoryStatus: Map[String, (Long, Long)] = { + assertNotStopped() env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => (blockManagerId.host + ":" + blockManagerId.port, mem) } @@ -1058,6 +1088,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getRDDStorageInfo: Array[RDDInfo] = { + assertNotStopped() val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) rddInfos.filter(_.isCached) @@ -1075,6 +1106,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getExecutorStorageStatus: Array[StorageStatus] = { + assertNotStopped() env.blockManager.master.getStorageStatus } @@ -1084,6 +1116,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getAllPools: Seq[Schedulable] = { + assertNotStopped() // TODO(xiajunluan): We should take nested pools into account taskScheduler.rootPool.schedulableQueue.toSeq } @@ -1094,6 +1127,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getPoolForName(pool: String): Option[Schedulable] = { + 
assertNotStopped() Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)) } @@ -1101,6 +1135,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Return current scheduling mode */ def getSchedulingMode: SchedulingMode.SchedulingMode = { + assertNotStopped() taskScheduler.schedulingMode } @@ -1206,16 +1241,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { postApplicationEnd() ui.foreach(_.stop()) - // Do this only if not stopped already - best case effort. - // prevent NPE if stopped more than once. - val dagSchedulerCopy = dagScheduler - dagScheduler = null - if (dagSchedulerCopy != null) { + if (!stopped) { + stopped = true env.metricsSystem.report() metadataCleaner.cancel() env.actorSystem.stop(heartbeatReceiver) cleaner.foreach(_.stop()) - dagSchedulerCopy.stop() + dagScheduler.stop() + dagScheduler = null taskScheduler = null // TODO: Cache.stop()? 
env.stop() @@ -1289,8 +1322,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli partitions: Seq[Int], allowLocal: Boolean, resultHandler: (Int, U) => Unit) { - if (dagScheduler == null) { - throw new SparkException("SparkContext has been shutdown") + if (stopped) { + throw new IllegalStateException("SparkContext has been shutdown") } val callSite = getCallSite val cleanedFunc = clean(func) @@ -1377,6 +1410,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], timeout: Long): PartialResult[R] = { + assertNotStopped() val callSite = getCallSite logInfo("Starting job: " + callSite.shortForm) val start = System.nanoTime @@ -1399,6 +1433,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli resultHandler: (Int, U) => Unit, resultFunc: => R): SimpleFutureAction[R] = { + assertNotStopped() val cleanF = clean(processPartition) val callSite = getCallSite val waiter = dagScheduler.submitJob( @@ -1417,11 +1452,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * for more information. */ def cancelJobGroup(groupId: String) { + assertNotStopped() dagScheduler.cancelJobGroup(groupId) } /** Cancel all jobs that have been scheduled or are running. */ def cancelAllJobs() { + assertNotStopped() dagScheduler.cancelAllJobs() } @@ -1468,13 +1505,20 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def getCheckpointDir = checkpointDir /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). 
*/ - def defaultParallelism: Int = taskScheduler.defaultParallelism + def defaultParallelism: Int = { + assertNotStopped() + taskScheduler.defaultParallelism + } /** Default min number of partitions for Hadoop RDDs when not given by user */ @deprecated("use defaultMinPartitions", "1.0.0") def defaultMinSplits: Int = math.min(defaultParallelism, 2) - /** Default min number of partitions for Hadoop RDDs when not given by user */ + /** + * Default min number of partitions for Hadoop RDDs when not given by user + * Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2. + * The reasons for this are discussed in https://github.com/mesos/spark/pull/718 + */ def defaultMinPartitions: Int = math.min(defaultParallelism, 2) private val nextShuffleId = new AtomicInteger(0) @@ -1942,7 +1986,7 @@ object SparkContext extends Logging { case "yarn-client" => val scheduler = try { val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + Class.forName("org.apache.spark.scheduler.cluster.YarnScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4d418037bd33f..1264a8126153b 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -326,6 +326,10 @@ object SparkEnv extends Logging { // Then we can start the metrics system. MetricsSystem.createMetricsSystem("driver", conf, securityManager) } else { + // We need to set the executor ID before the MetricsSystem is created because sources and + // sinks specified in the metrics configuration file will want to incorporate this executor's + // ID into the metrics they report. 
+ conf.set("spark.executor.id", executorId) val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager) ms.start() ms diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 62bf18d82d9b0..0f91c942ecd50 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -348,6 +348,19 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f) + /** + * Reduces the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree + * @see [[org.apache.spark.api.java.JavaRDDLike#reduce]] + */ + def treeReduce(f: JFunction2[T, T, T], depth: Int): T = rdd.treeReduce(f, depth) + + /** + * [[org.apache.spark.api.java.JavaRDDLike#treeReduce]] with suggested depth 2. + */ + def treeReduce(f: JFunction2[T, T, T]): T = treeReduce(f, 2) + /** * Aggregate the elements of each partition, and then the results for all the partitions, using a * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to @@ -369,6 +382,30 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { combOp: JFunction2[U, U, U]): U = rdd.aggregate(zeroValue)(seqOp, combOp)(fakeClassTag[U]) + /** + * Aggregates the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree + * @see [[org.apache.spark.api.java.JavaRDDLike#aggregate]] + */ + def treeAggregate[U]( + zeroValue: U, + seqOp: JFunction2[U, T, U], + combOp: JFunction2[U, U, U], + depth: Int): U = { + rdd.treeAggregate(zeroValue)(seqOp, combOp, depth)(fakeClassTag[U]) + } + + /** + * [[org.apache.spark.api.java.JavaRDDLike#treeAggregate]] with suggested depth 2. 
+ */ + def treeAggregate[U]( + zeroValue: U, + seqOp: JFunction2[U, T, U], + combOp: JFunction2[U, U, U]): U = { + treeAggregate(zeroValue, seqOp, combOp, 2) + } + /** * Return the number of elements in the RDD. */ diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 5ba66178e2b78..c9181a29d4756 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -138,6 +138,11 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { mapWritable.put(convertToWritable(k), convertToWritable(v)) } mapWritable + case array: Array[Any] => { + val arrayWriteable = new ArrayWritable(classOf[Writable]) + arrayWriteable.set(array.map(convertToWritable(_))) + arrayWriteable + } case other => throw new SparkException( s"Data of type ${other.getClass.getName} cannot be used") } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 4ac666c54fbcd..119e0459c5d1b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -67,17 +67,16 @@ private[spark] class PythonRDD( envVars += ("SPARK_REUSE_WORKER" -> "1") } val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) + // Whether is the worker released into idle pool + @volatile var released = false // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) - var complete_cleanly = false context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() writerThread.join() - if (reuse_worker && complete_cleanly) { - env.releasePythonWorker(pythonExec, envVars.toMap, worker) - } else { + if (!reuse_worker || 
!released) { try { worker.close() } catch { @@ -145,8 +144,12 @@ private[spark] class PythonRDD( stream.readFully(update) accumulator += Collections.singletonList(update) } + // Check whether the worker is ready to be re-used. if (stream.readInt() == SpecialLengths.END_OF_STREAM) { - complete_cleanly = true + if (reuse_worker) { + env.releasePythonWorker(pythonExec, envVars.toMap, worker) + released = true + } } null } diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index a4153aaa926f8..fb52a960e0765 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -153,7 +153,10 @@ private[spark] object SerDeUtil extends Logging { iter.flatMap { row => val obj = unpickle.loads(row) if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala + obj match { + case array: Array[Any] => array.toSeq + case _ => obj.asInstanceOf[JArrayList[_]].asScala + } } else { Seq(obj) } @@ -199,7 +202,10 @@ private[spark] object SerDeUtil extends Logging { * representation is serialized */ def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = { - val (keyFailed, valueFailed) = checkPickle(rdd.first()) + val (keyFailed, valueFailed) = rdd.take(1) match { + case Array() => (false, false) + case Array(first) => checkPickle(first) + } rdd.mapPartitions { iter => val cleaned = iter.map { case (k, v) => @@ -226,10 +232,12 @@ private[spark] object SerDeUtil extends Logging { } val rdd = pythonToJava(pyRDD, batched).rdd - rdd.first match { - case obj if isPair(obj) => + rdd.take(1) match { + case Array(obj) if isPair(obj) => // we only accept (K, V) - case other => throw new SparkException( + case Array() => + // we also accept empty collections + case Array(other) => throw new SparkException( s"RDD element of type ${other.getClass.getName} cannot be used") } rdd.map { obj => diff --git 
a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 57f9faf5ddd1d..211e3ede53d9c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -133,10 +133,9 @@ class SparkHadoopUtil extends Logging { * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). * Returns None if the required method can't be found. */ - private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration) - : Option[() => Long] = { + private[spark] def getFSBytesReadOnThreadCallback(): Option[() => Long] = { try { - val threadStats = getFileSystemThreadStatistics(path, conf) + val threadStats = getFileSystemThreadStatistics() val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead") val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum val baselineBytesRead = f() @@ -156,10 +155,9 @@ class SparkHadoopUtil extends Logging { * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). * Returns None if the required method can't be found. 
*/ - private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration) - : Option[() => Long] = { + private[spark] def getFSBytesWrittenOnThreadCallback(): Option[() => Long] = { try { - val threadStats = getFileSystemThreadStatistics(path, conf) + val threadStats = getFileSystemThreadStatistics() val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten") val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum val baselineBytesWritten = f() @@ -172,10 +170,8 @@ class SparkHadoopUtil extends Logging { } } - private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = { - val qualifiedPath = path.getFileSystem(conf).makeQualified(path) - val scheme = qualifiedPath.toUri().getScheme() - val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme)) + private def getFileSystemThreadStatistics(): Seq[AnyRef] = { + val stats = FileSystem.getAllStatistics() stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 2b084a2d73b78..0ae45f4ad9130 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -203,7 +203,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (!logInfos.isEmpty) { val newApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() def addIfAbsent(info: FsApplicationHistoryInfo) = { - if (!newApps.contains(info.id)) { + if (!newApps.contains(info.id) || + newApps(info.id).logPath.endsWith(EventLoggingListener.IN_PROGRESS) && + !info.logPath.endsWith(EventLoggingListener.IN_PROGRESS)) { newApps += (info.id -> info) } } diff --git 
a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 9a4adfbbb3d71..823825302658c 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -84,8 +84,12 @@ private[spark] class CoarseGrainedExecutorBackend( } case x: DisassociatedEvent => - logError(s"Driver $x disassociated! Shutting down.") - System.exit(1) + if (x.remoteAddress == driver.anchorPath.address) { + logError(s"Driver $x disassociated! Shutting down.") + System.exit(1) + } else { + logWarning(s"Received irrelevant DisassociatedEvent $x") + } case StopExecutor => logInfo("Driver commanded a shutdown") diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 42566d1a14093..312bb3a1daaa3 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -41,11 +41,14 @@ import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils} */ private[spark] class Executor( executorId: String, - slaveHostname: String, + executorHostname: String, env: SparkEnv, isLocal: Boolean = false) extends Logging { + + logInfo(s"Starting executor ID $executorId on host $executorHostname") + // Application dependencies (added through SparkContext) that we've fetched so far on this node. // Each map holds the master's timestamp for the version of that file or JAR we got. 
private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() @@ -58,12 +61,12 @@ private[spark] class Executor( @volatile private var isStopped = false // No ip or host:port - just hostname - Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") + Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") // must not have port specified. - assert (0 == Utils.parseHostPort(slaveHostname)._2) + assert (0 == Utils.parseHostPort(executorHostname)._2) // Make sure the local hostname we report matches the cluster scheduler's name for this host - Utils.setCustomHostname(slaveHostname) + Utils.setCustomHostname(executorHostname) if (!isLocal) { // Setup an uncaught exception handler for non-local mode. @@ -73,7 +76,6 @@ private[spark] class Executor( } val executorSource = new ExecutorSource(this, executorId) - conf.set("spark.executor.id", executorId) if (!isLocal) { env.metricsSystem.registerSource(executorSource) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index ddb5903bf6875..97912c68c5982 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -19,7 +19,6 @@ package org.apache.spark.executor import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.executor.DataReadMethod import org.apache.spark.executor.DataReadMethod.DataReadMethod import scala.collection.mutable.ArrayBuffer diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 45633e3de01dd..83e8eb71260eb 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -130,8 +130,8 @@ private[spark] class MetricsSystem private ( if (appId.isDefined && executorId.isDefined) { 
MetricRegistry.name(appId.get, executorId.get, source.sourceName) } else { - // Only Driver and Executor are set spark.app.id and spark.executor.id. - // For instance, Master and Worker are not related to a specific application. + // Only Driver and Executor set spark.app.id and spark.executor.id. + // Other instance types, e.g. Master and Worker, are not related to a specific application. val warningMsg = s"Using default name $defaultName for source because %s is not set." if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) } if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 03c4137ca0a81..ee22c6656e69e 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -184,14 +184,16 @@ private[nio] class ConnectionManager( // to be able to track asynchronous messages private val idCount: AtomicInteger = new AtomicInteger(1) + private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + private val selectorThread = new Thread("connection-manager-thread") { override def run() = ConnectionManager.this.run() } selectorThread.setDaemon(true) + // start this thread last, since it invokes run(), which accesses members above selectorThread.start() - private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - private def triggerWrite(key: SelectionKey) { val conn = connectionsByKey.getOrElse(key, null) if (conn == null) return @@ -232,7 +234,6 @@ private[nio] class ConnectionManager( } ) } - private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() private def triggerRead(key: SelectionKey) { val 
conn = connectionsByKey.getOrElse(key, null) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 056aef0bc210a..c3e3931042de2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.mapred.Reporter import org.apache.hadoop.mapred.JobID import org.apache.hadoop.mapred.TaskAttemptID import org.apache.hadoop.mapred.TaskID +import org.apache.hadoop.mapred.lib.CombineFileSplit import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ @@ -218,13 +219,13 @@ class HadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse( + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { split.inputSplit.value match { - case split: FileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, jobConf) + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None } - ) + } inputMetrics.setBytesReadCallback(bytesReadCallback) var reader: RecordReader[K, V] = null @@ -254,7 +255,8 @@ class HadoopRDD[K, V]( reader.close() if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() - } else if (split.inputSplit.value.isInstanceOf[FileSplit]) { + } else if (split.inputSplit.value.isInstanceOf[FileSplit] || + split.inputSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. 
try { diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 7b0e3c87ccff4..d86f95ac3e485 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -25,7 +25,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.input.WholeTextFileInputFormat @@ -34,7 +34,7 @@ import org.apache.spark.Logging import org.apache.spark.Partition import org.apache.spark.SerializableWritable import org.apache.spark.{SparkContext, TaskContext} -import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.util.Utils @@ -114,13 +114,13 @@ class NewHadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. 
Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse( + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { split.serializableHadoopSplit.value match { - case split: FileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, conf) + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None } - ) + } inputMetrics.setBytesReadCallback(bytesReadCallback) val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) @@ -163,7 +163,8 @@ class NewHadoopRDD[K, V]( reader.close() if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 0f37d830ef34f..49b88a90ab5af 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -990,7 +990,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val committer = format.getOutputCommitter(hadoopContext) committer.setupTask(hadoopContext) - val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] try { @@ -1061,7 +1061,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. 
val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt - val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() @@ -1086,11 +1086,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.commitJob() } - private def initHadoopOutputMetrics(context: TaskContext, config: Configuration) - : (OutputMetrics, Option[() => Long]) = { - val bytesWrittenCallback = Option(config.get("mapreduce.output.fileoutputformat.outputdir")) - .map(new Path(_)) - .flatMap(SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(_, config)) + private def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, Option[() => Long]) = { + val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) if (bytesWrittenCallback.isDefined) { context.taskMetrics.outputMetrics = Some(outputMetrics) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 97012c7033f9f..5f39384975f9b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -76,10 +76,27 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli * on RDD internals. */ abstract class RDD[T: ClassTag]( - @transient private var sc: SparkContext, + @transient private var _sc: SparkContext, @transient private var deps: Seq[Dependency[_]] ) extends Serializable with Logging { + if (classOf[RDD[_]].isAssignableFrom(elementClassTag.runtimeClass)) { + // This is a warning instead of an exception in order to avoid breaking user programs that + // might have defined nested RDDs without running jobs with them. 
+ logWarning("Spark does not support nested RDDs (see SPARK-5063)") + } + + private def sc: SparkContext = { + if (_sc == null) { + throw new SparkException( + "RDD transformations and actions can only be invoked by the driver, not inside of other " + + "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " + + "the values transformation and count action cannot be performed inside of the rdd1.map " + + "transformation. For more information, see SPARK-5063.") + } + _sc + } + /** Construct an RDD with just a one-to-one dependency on one parent */ def this(@transient oneParent: RDD[_]) = this(oneParent.context , List(new OneToOneDependency(oneParent))) @@ -883,6 +900,38 @@ abstract class RDD[T: ClassTag]( jobResult.getOrElse(throw new UnsupportedOperationException("empty collection")) } + /** + * Reduces the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree (default: 2) + * @see [[org.apache.spark.rdd.RDD#reduce]] + */ + def treeReduce(f: (T, T) => T, depth: Int = 2): T = { + require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") + val cleanF = context.clean(f) + val reducePartition: Iterator[T] => Option[T] = iter => { + if (iter.hasNext) { + Some(iter.reduceLeft(cleanF)) + } else { + None + } + } + val partiallyReduced = mapPartitions(it => Iterator(reducePartition(it))) + val op: (Option[T], Option[T]) => Option[T] = (c, x) => { + if (c.isDefined && x.isDefined) { + Some(cleanF(c.get, x.get)) + } else if (c.isDefined) { + c + } else if (x.isDefined) { + x + } else { + None + } + } + partiallyReduced.treeAggregate(Option.empty[T])(op, op, depth) + .getOrElse(throw new UnsupportedOperationException("empty collection")) + } + /** * Aggregate the elements of each partition, and then the results for all the partitions, using a * given associative function and a neutral "zero value". 
The function op(t1, t2) is allowed to @@ -918,6 +967,37 @@ abstract class RDD[T: ClassTag]( jobResult } + /** + * Aggregates the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree (default: 2) + * @see [[org.apache.spark.rdd.RDD#aggregate]] + */ + def treeAggregate[U: ClassTag](zeroValue: U)( + seqOp: (U, T) => U, + combOp: (U, U) => U, + depth: Int = 2): U = { + require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") + if (partitions.size == 0) { + return Utils.clone(zeroValue, context.env.closureSerializer.newInstance()) + } + val cleanSeqOp = context.clean(seqOp) + val cleanCombOp = context.clean(combOp) + val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) + var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) + var numPartitions = partiallyAggregated.partitions.size + val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) + // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. + while (numPartitions > scale + numPartitions / scale) { + numPartitions /= scale + val curNumPartitions = numPartitions + partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => + iter.map((i % curNumPartitions, _)) + }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values + } + partiallyAggregated.reduce(cleanCombOp) + } + /** * Return the number of elements in the RDD. 
*/ diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index e5d1eb767e109..8f5ceaa5de515 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -91,11 +91,11 @@ case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockMan case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent @DeveloperApi -case class SparkListenerExecutorAdded(executorId: String, executorInfo: ExecutorInfo) +case class SparkListenerExecutorAdded(time: Long, executorId: String, executorInfo: ExecutorInfo) extends SparkListenerEvent @DeveloperApi -case class SparkListenerExecutorRemoved(executorId: String) +case class SparkListenerExecutorRemoved(time: Long, executorId: String, reason: String) extends SparkListenerEvent /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5786d367464f4..103a5c053c289 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -108,7 +108,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") } } - listenerBus.post(SparkListenerExecutorAdded(executorId, data)) + listenerBus.post( + SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) makeOffers() } @@ -216,7 +217,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) scheduler.executorLost(executorId, SlaveLost(reason)) - 
listenerBus.post(SparkListenerExecutorRemoved(executorId)) + listenerBus.post( + SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason)) case None => logError(s"Asked to remove non-existent executor $executorId") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 79c9051e88691..c3c546be6da15 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -269,7 +269,7 @@ private[spark] class MesosSchedulerBackend( mesosTasks.foreach { case (slaveId, tasks) => slaveIdToWorkerOffer.get(slaveId).foreach(o => - listenerBus.post(SparkListenerExecutorAdded(slaveId, + listenerBus.post(SparkListenerExecutorAdded(System.currentTimeMillis(), slaveId, new ExecutorInfo(o.host, o.cores))) ) d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) @@ -327,7 +327,7 @@ private[spark] class MesosSchedulerBackend( synchronized { if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) { // We lost the executor on this slave, so remember that it's gone - removeExecutor(taskIdToSlaveId(tid)) + removeExecutor(taskIdToSlaveId(tid), "Lost executor") } if (isFinished(status.getState)) { taskIdToSlaveId.remove(tid) @@ -359,9 +359,9 @@ private[spark] class MesosSchedulerBackend( /** * Remove executor associated with slaveId in a thread safe manner. 
*/ - private def removeExecutor(slaveId: String) = { + private def removeExecutor(slaveId: String, reason: String) = { synchronized { - listenerBus.post(SparkListenerExecutorRemoved(slaveId)) + listenerBus.post(SparkListenerExecutorRemoved(System.currentTimeMillis(), slaveId, reason)) slaveIdsWithExecutors -= slaveId } } @@ -369,7 +369,7 @@ private[spark] class MesosSchedulerBackend( private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) { inClassLoader() { logInfo("Mesos slave lost: " + slaveId.getValue) - removeExecutor(slaveId.getValue) + removeExecutor(slaveId.getValue, reason.toString) scheduler.executorLost(slaveId.getValue, reason) } } diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 6f446c5a95a0a..4307029d44fbb 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -24,8 +24,10 @@ private[spark] object ToolTips { scheduler delay is large, consider decreasing the size of tasks or decreasing the size of task results.""" - val TASK_DESERIALIZATION_TIME = - """Time spent deserializating the task closure on the executor.""" + val TASK_DESERIALIZATION_TIME = "Time spent deserializing the task closure on the executor." + + val SHUFFLE_READ_BLOCKED_TIME = + "Time that the task spent blocked waiting for shuffle data to be read from remote machines." val INPUT = "Bytes read from Hadoop or from Spark storage." diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 09a936c2234c0..d8be1b20b3acd 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -132,6 +132,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Task Deserialization Time + {if (hasShuffleRead) { +
  • + + + Shuffle Read Blocked Time + +
  • + }}
  • @@ -167,7 +176,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ {if (hasInput) Seq(("Input", "")) else Nil} ++ {if (hasOutput) Seq(("Output", "")) else Nil} ++ - {if (hasShuffleRead) Seq(("Shuffle Read", "")) else Nil} ++ + {if (hasShuffleRead) { + Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), + ("Shuffle Read", "")) + } else { + Nil + }} ++ {if (hasShuffleWrite) Seq(("Write Time", ""), ("Shuffle Write", "")) else Nil} ++ {if (hasBytesSpilled) Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) else Nil} ++ @@ -271,6 +285,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } val outputQuantiles = Output +: getFormattedSizeQuantiles(outputSizes) + val shuffleReadBlockedTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble + } + val shuffleReadBlockedQuantiles = Shuffle Read Blocked Time +: + getFormattedTimeQuantiles(shuffleReadBlockedTimes) + val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble } @@ -308,7 +328,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {gettingResultQuantiles}, if (hasInput) {inputQuantiles} else Nil, if (hasOutput) {outputQuantiles} else Nil, - if (hasShuffleRead) {shuffleReadQuantiles} else Nil, + if (hasShuffleRead) { + + {shuffleReadBlockedQuantiles} + + {shuffleReadQuantiles} + } else { + Nil + }, if (hasShuffleWrite) {shuffleWriteQuantiles} else Nil, if (hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, if (hasBytesSpilled) {diskBytesSpilledQuantiles} else Nil) @@ -377,6 +404,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { .map(m => s"${Utils.bytesToString(m.bytesWritten)}") .getOrElse("") + val 
maybeShuffleReadBlockedTime = metrics.flatMap(_.shuffleReadMetrics).map(_.fetchWaitTime) + val shuffleReadBlockedTimeSortable = maybeShuffleReadBlockedTime.map(_.toString).getOrElse("") + val shuffleReadBlockedTimeReadable = + maybeShuffleReadBlockedTime.map(ms => UIUtils.formatDuration(ms)).getOrElse("") + val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead) val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("") val shuffleReadReadable = maybeShuffleRead.map(Utils.bytesToString).getOrElse("") @@ -449,6 +481,10 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { }} {if (hasShuffleRead) { + + {shuffleReadBlockedTimeReadable} + {shuffleReadReadable} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 2d13bb6ddde42..37cf2c207ba40 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -27,6 +27,7 @@ package org.apache.spark.ui.jobs private[spark] object TaskDetailsClassNames { val SCHEDULER_DELAY = "scheduler_delay" val TASK_DESERIALIZATION_TIME = "deserialization_time" + val SHUFFLE_READ_BLOCKED_TIME = "fetch_wait_time" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index f896b5072e4fa..b5f736dc41c6c 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -204,13 +204,16 @@ private[spark] object JsonProtocol { def executorAddedToJson(executorAdded: SparkListenerExecutorAdded): JValue = { ("Event" -> Utils.getFormattedClassName(executorAdded)) ~ + ("Timestamp" -> executorAdded.time) ~ ("Executor 
ID" -> executorAdded.executorId) ~ ("Executor Info" -> executorInfoToJson(executorAdded.executorInfo)) } def executorRemovedToJson(executorRemoved: SparkListenerExecutorRemoved): JValue = { ("Event" -> Utils.getFormattedClassName(executorRemoved)) ~ - ("Executor ID" -> executorRemoved.executorId) + ("Timestamp" -> executorRemoved.time) ~ + ("Executor ID" -> executorRemoved.executorId) ~ + ("Removed Reason" -> executorRemoved.reason) } /** ------------------------------------------------------------------- * @@ -554,14 +557,17 @@ private[spark] object JsonProtocol { } def executorAddedFromJson(json: JValue): SparkListenerExecutorAdded = { + val time = (json \ "Timestamp").extract[Long] val executorId = (json \ "Executor ID").extract[String] val executorInfo = executorInfoFromJson(json \ "Executor Info") - SparkListenerExecutorAdded(executorId, executorInfo) + SparkListenerExecutorAdded(time, executorId, executorInfo) } def executorRemovedFromJson(json: JValue): SparkListenerExecutorRemoved = { + val time = (json \ "Timestamp").extract[Long] val executorId = (json \ "Executor ID").extract[String] - SparkListenerExecutorRemoved(executorId) + val reason = (json \ "Removed Reason").extract[String] + SparkListenerExecutorRemoved(time, executorId, reason) } /** --------------------------------------------------------------------- * diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2c04e4ddfbcb7..86ac307fc84ba 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -410,10 +410,10 @@ private[spark] object Utils extends Logging { // Decompress the file if it's a .tar or .tar.gz if (fileName.endsWith(".tar.gz") || fileName.endsWith(".tgz")) { logInfo("Untarring " + fileName) - Utils.execute(Seq("tar", "-xzf", fileName), targetDir) + executeAndGetOutput(Seq("tar", "-xzf", fileName), targetDir) } else if 
(fileName.endsWith(".tar")) { logInfo("Untarring " + fileName) - Utils.execute(Seq("tar", "-xf", fileName), targetDir) + executeAndGetOutput(Seq("tar", "-xf", fileName), targetDir) } // Make the file executable - That's necessary for scripts FileUtil.chmod(targetFile.getAbsolutePath, "a+x") @@ -956,25 +956,25 @@ private[spark] object Utils extends Logging { } /** - * Execute a command in the given working directory, throwing an exception if it completes - * with an exit code other than 0. + * Execute a command and return the process running the command. */ - def execute(command: Seq[String], workingDir: File) { - val process = new ProcessBuilder(command: _*) - .directory(workingDir) - .redirectErrorStream(true) - .start() - new Thread("read stdout for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines()) { - System.err.println(line) - } - } - }.start() - val exitCode = process.waitFor() - if (exitCode != 0) { - throw new SparkException("Process " + command + " exited with code " + exitCode) + def executeCommand( + command: Seq[String], + workingDir: File = new File("."), + extraEnvironment: Map[String, String] = Map.empty, + redirectStderr: Boolean = true): Process = { + val builder = new ProcessBuilder(command: _*).directory(workingDir) + val environment = builder.environment() + for ((key, value) <- extraEnvironment) { + environment.put(key, value) + } + val process = builder.start() + if (redirectStderr) { + val threadName = "redirect stderr for command " + command(0) + def log(s: String): Unit = logInfo(s) + processStreamByLine(threadName, process.getErrorStream, log) } + process } /** @@ -983,31 +983,13 @@ private[spark] object Utils extends Logging { def executeAndGetOutput( command: Seq[String], workingDir: File = new File("."), - extraEnvironment: Map[String, String] = Map.empty): String = { - val builder = new ProcessBuilder(command: _*) - .directory(workingDir) - val environment = 
builder.environment() - for ((key, value) <- extraEnvironment) { - environment.put(key, value) - } - - val process = builder.start() - new Thread("read stderr for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getErrorStream).getLines()) { - logInfo(line) - } - } - }.start() + extraEnvironment: Map[String, String] = Map.empty, + redirectStderr: Boolean = true): String = { + val process = executeCommand(command, workingDir, extraEnvironment, redirectStderr) val output = new StringBuffer - val stdoutThread = new Thread("read stdout for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines()) { - output.append(line) - } - } - } - stdoutThread.start() + val threadName = "read stdout for " + command(0) + def appendToOutput(s: String): Unit = output.append(s) + val stdoutThread = processStreamByLine(threadName, process.getInputStream, appendToOutput) val exitCode = process.waitFor() stdoutThread.join() // Wait for it to finish reading output if (exitCode != 0) { @@ -1017,6 +999,25 @@ private[spark] object Utils extends Logging { output.toString } + /** + * Return and start a daemon thread that processes the content of the input stream line by line. 
+ */ + def processStreamByLine( + threadName: String, + inputStream: InputStream, + processLine: String => Unit): Thread = { + val t = new Thread(threadName) { + override def run() { + for (line <- Source.fromInputStream(inputStream).getLines()) { + processLine(line) + } + } + } + t.setDaemon(true) + t.start() + t + } + /** * Execute a block of code that evaluates to Unit, forwarding any uncaught exceptions to the * default UncaughtExceptionHandler diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 004de05c10ee1..b16a1e9460286 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -492,6 +492,36 @@ public Integer call(Integer a, Integer b) { Assert.assertEquals(33, sum); } + @Test + public void treeReduce() { + JavaRDD rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10); + Function2 add = new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }; + for (int depth = 1; depth <= 10; depth++) { + int sum = rdd.treeReduce(add, depth); + Assert.assertEquals(-5, sum); + } + } + + @Test + public void treeAggregate() { + JavaRDD rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10); + Function2 add = new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }; + for (int depth = 1; depth <= 10; depth++) { + int sum = rdd.treeAggregate(0, add, add, depth); + Assert.assertEquals(-5, sum); + } + } + @SuppressWarnings("unchecked") @Test public void aggregateByKey() { diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 8a54360e81795..9bd5dfec8703a 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -28,31 +28,29 @@ import org.apache.spark.util.Utils 
class DriverSuite extends FunSuite with Timeouts { - test("driver should exit after finishing") { + test("driver should exit after finishing without cleanup (SPARK-530)") { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" - val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) + val masters = Table("master", "local", "local-cluster[2,1,512]") forAll(masters) { (master: String) => - failAfter(60 seconds) { - Utils.executeAndGetOutput( - Seq(s"$sparkHome/bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master), - new File(sparkHome), - Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) - } + val process = Utils.executeCommand( + Seq(s"$sparkHome/bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master), + new File(sparkHome), + Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) + failAfter(60 seconds) { process.waitFor() } + // Ensure we still kill the process in case it timed out + process.destroy() } } } /** - * Program that creates a Spark driver but doesn't call SparkContext.stop() or - * Sys.exit() after finishing. + * Program that creates a Spark driver but doesn't call SparkContext#stop() or + * sys.exit() after finishing. 
*/ object DriverWithoutCleanup { def main(args: Array[String]) { Utils.configTestLog4j("INFO") - // Bind the web UI to an ephemeral port in order to avoid conflicts with other tests running on - // the same machine (we shouldn't just disable the UI here, since that might mask bugs): - val conf = new SparkConf().set("spark.ui.port", "0") + val conf = new SparkConf val sc = new SparkContext(args(0), "DriverWithoutCleanup", conf) sc.parallelize(1 to 100, 4).count() } diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 7584ae79fc920..21487bc24d58a 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -171,11 +171,11 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter assert(jobB.get() === 100) } - ignore("two jobs sharing the same stage") { + test("two jobs sharing the same stage") { // sem1: make sure cancel is issued after some tasks are launched - // sem2: make sure the first stage is not finished until cancel is issued + // twoJobsSharingStageSemaphore: + // make sure the first stage is not finished until cancel is issued val sem1 = new Semaphore(0) - val sem2 = new Semaphore(0) sc = new SparkContext("local[2]", "test") sc.addSparkListener(new SparkListener { @@ -186,7 +186,7 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter // Create two actions that would share the some stages. val rdd = sc.parallelize(1 to 10, 2).map { i => - sem2.acquire() + JobCancellationSuite.twoJobsSharingStageSemaphore.acquire() (i, i) }.reduceByKey(_+_) val f1 = rdd.collectAsync() @@ -196,13 +196,13 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter future { sem1.acquire() f1.cancel() - sem2.release(10) + JobCancellationSuite.twoJobsSharingStageSemaphore.release(10) } - // Expect both to fail now. 
- // TODO: update this test when we change Spark so cancelling f1 wouldn't affect f2. + // Expect f1 to fail due to cancellation, intercept[SparkException] { f1.get() } - intercept[SparkException] { f2.get() } + // but f2 should not be affected + f2.get() } def testCount() { @@ -268,4 +268,5 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter object JobCancellationSuite { val taskStartedSemaphore = new Semaphore(0) val taskCancelledSemaphore = new Semaphore(0) + val twoJobsSharingStageSemaphore = new Semaphore(0) } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 8ae4f243ec1ae..bbed8ddc6bafc 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -149,7 +149,7 @@ class SparkContextSchedulerCreationSuite } test("yarn-client") { - testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnScheduler") } def testMesos(master: String, expectedClass: Class[_], coarse: Boolean) { diff --git a/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala b/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala new file mode 100644 index 0000000000000..f8c39326145e1 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import org.scalatest.FunSuite + +import org.apache.spark.SharedSparkContext + +class SerDeUtilSuite extends FunSuite with SharedSparkContext { + + test("Converting an empty pair RDD to python does not throw an exception (SPARK-5441)") { + val emptyRdd = sc.makeRDD(Seq[(Any, Any)]()) + SerDeUtil.pairRDDToPython(emptyRdd, 10) + } + + test("Converting an empty python RDD to pair RDD does not throw an exception (SPARK-5441)") { + val emptyRdd = sc.makeRDD(Seq[(Any, Any)]()) + val javaRdd = emptyRdd.toJavaRDD() + val pythonRdd = SerDeUtil.javaToPython(javaRdd) + SerDeUtil.pythonToPairRDD(pythonRdd, false) + } +} + diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index b0a70f012f1f3..af3272692d7a1 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -170,6 +170,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { testPackage.runCallSiteTest(sc) } + test("Broadcast variables cannot be created after SparkContext is stopped (SPARK-5065)") { + sc = new SparkContext("local", "test") + sc.stop() + val thrown = intercept[IllegalStateException] { + sc.broadcast(Seq(1, 2, 3)) + } + assert(thrown.getMessage.toLowerCase.contains("stopped")) + } + /** * Verify the persistence of state associated with an HttpBroadcast in either local mode or * local-cluster mode (when distributed = true). 
@@ -349,8 +358,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { package object testPackage extends Assertions { def runCallSiteTest(sc: SparkContext) { - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) - val broadcast = sc.broadcast(rdd) + val broadcast = sc.broadcast(Array(1, 2, 3, 4)) broadcast.destroy() val thrown = intercept[SparkException] { broadcast.value } assert(thrown.getMessage.contains("BroadcastSuite.scala")) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 065b7534cece6..82628ad3abd99 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -21,25 +21,28 @@ import java.io._ import scala.collection.mutable.ArrayBuffer +import org.scalatest.FunSuite +import org.scalatest.Matchers +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ + import org.apache.spark._ import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.util.{ResetSystemProperties, Utils} -import org.scalatest.FunSuite -import org.scalatest.Matchers // Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch // of properties that neeed to be cleared after tests. 
-class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties { +class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties with Timeouts { def beforeAll() { System.setProperty("spark.testing", "true") } - val noOpOutputStream = new OutputStream { + private val noOpOutputStream = new OutputStream { def write(b: Int) = {} } /** Simple PrintStream that reads data into a buffer */ - class BufferPrintStream extends PrintStream(noOpOutputStream) { + private class BufferPrintStream extends PrintStream(noOpOutputStream) { var lineBuffer = ArrayBuffer[String]() override def println(line: String) { lineBuffer += line @@ -47,7 +50,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties } /** Returns true if the script exits and the given search string is printed. */ - def testPrematureExit(input: Array[String], searchString: String) = { + private def testPrematureExit(input: Array[String], searchString: String) = { val printStream = new BufferPrintStream() SparkSubmit.printStream = printStream @@ -290,7 +293,6 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local", - "--conf", "spark.ui.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -305,7 +307,6 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--name", "testApp", "--master", "local-cluster[2,1,512]", "--jars", jarsString, - "--conf", "spark.ui.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -430,15 +431,18 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties } // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. 
- def runSparkSubmit(args: Seq[String]): String = { + private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - Utils.executeAndGetOutput( + val process = Utils.executeCommand( Seq("./bin/spark-submit") ++ args, new File(sparkHome), Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) + failAfter(60 seconds) { process.waitFor() } + // Ensure we still kill the process in case it timed out + process.destroy() } - def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { + private def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { val tmpDir = Utils.createTempDir() val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf") diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 8379883e065e7..3fbc1a21d10ed 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -167,6 +167,29 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers list.size should be (1) } + test("history file is renamed from inprogress to completed") { + val conf = new SparkConf() + .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) + .set("spark.testing", "true") + val provider = new FsHistoryProvider(conf) + + val logFile1 = new File(testDir, "app1" + EventLoggingListener.IN_PROGRESS) + writeFile(logFile1, true, None, + SparkListenerApplicationStart("app1", Some("app1"), 1L, "test"), + SparkListenerApplicationEnd(2L) + ) + provider.checkForLogs() + val appListBeforeRename = provider.getListing() + appListBeforeRename.size should be (1) + appListBeforeRename.head.logPath should endWith(EventLoggingListener.IN_PROGRESS) + + logFile1.renameTo(new File(testDir, "app1")) + 
provider.checkForLogs() + val appListAfterRename = provider.getListing() + appListAfterRename.size should be (1) + appListAfterRename.head.logPath should not endWith(EventLoggingListener.IN_PROGRESS) + } + private def writeFile(file: File, isNewFormat: Boolean, codec: Option[CompressionCodec], events: SparkListenerEvent*) = { val out = diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala index 1a28a9a187cd7..372d7aa453008 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -43,7 +43,7 @@ class WorkerArgumentsTest extends FunSuite { } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } val conf = new MySparkConf() @@ -62,7 +62,7 @@ class WorkerArgumentsTest extends FunSuite { } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } val conf = new MySparkConf() diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 10a39990f80ce..81db66ae17464 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -26,7 +26,16 @@ import org.scalatest.FunSuite import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} +import org.apache.hadoop.mapred.{FileSplit => OldFileSplit, InputSplit => OldInputSplit, JobConf, + LineRecordReader => OldLineRecordReader, RecordReader => OldRecordReader, Reporter, + TextInputFormat => OldTextInputFormat} 
+import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => OldCombineFileInputFormat, + CombineFileSplit => OldCombineFileSplit, CombineFileRecordReader => OldCombineFileRecordReader} +import org.apache.hadoop.mapreduce.{InputSplit => NewInputSplit, RecordReader => NewRecordReader, + TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombineFileInputFormat, + CombineFileRecordReader => NewCombineFileRecordReader, CombineFileSplit => NewCombineFileSplit, + FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.spark.SharedSparkContext import org.apache.spark.deploy.SparkHadoopUtil @@ -202,7 +211,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { val fs = FileSystem.getLocal(new Configuration()) val outPath = new Path(fs.getWorkingDirectory, "outdir") - if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(outPath, fs.getConf).isDefined) { + if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { val taskBytesWritten = new ArrayBuffer[Long]() sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { @@ -225,4 +234,88 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { } } } + + test("input metrics with old CombineFileInputFormat") { + val bytesRead = runAndReturnBytesRead { + sc.hadoopFile(tmpFilePath, classOf[OldCombineTextInputFormat], classOf[LongWritable], + classOf[Text], 2).count() + } + assert(bytesRead >= tmpFile.length()) + } + + test("input metrics with new CombineFileInputFormat") { + val bytesRead = runAndReturnBytesRead { + sc.newAPIHadoopFile(tmpFilePath, classOf[NewCombineTextInputFormat], classOf[LongWritable], + classOf[Text], new Configuration()).count() + } + assert(bytesRead >= tmpFile.length()) + } +} + +/** + * Hadoop 2 has a version of this, but we can't use it for backwards compatibility + */ +class OldCombineTextInputFormat extends 
OldCombineFileInputFormat[LongWritable, Text] { + override def getRecordReader(split: OldInputSplit, conf: JobConf, reporter: Reporter) + : OldRecordReader[LongWritable, Text] = { + new OldCombineFileRecordReader[LongWritable, Text](conf, + split.asInstanceOf[OldCombineFileSplit], reporter, classOf[OldCombineTextRecordReaderWrapper] + .asInstanceOf[Class[OldRecordReader[LongWritable, Text]]]) + } +} + +class OldCombineTextRecordReaderWrapper( + split: OldCombineFileSplit, + conf: Configuration, + reporter: Reporter, + idx: Integer) extends OldRecordReader[LongWritable, Text] { + + val fileSplit = new OldFileSplit(split.getPath(idx), + split.getOffset(idx), + split.getLength(idx), + split.getLocations()) + + val delegate: OldLineRecordReader = new OldTextInputFormat().getRecordReader(fileSplit, + conf.asInstanceOf[JobConf], reporter).asInstanceOf[OldLineRecordReader] + + override def next(key: LongWritable, value: Text): Boolean = delegate.next(key, value) + override def createKey(): LongWritable = delegate.createKey() + override def createValue(): Text = delegate.createValue() + override def getPos(): Long = delegate.getPos + override def close(): Unit = delegate.close() + override def getProgress(): Float = delegate.getProgress +} + +/** + * Hadoop 2 has a version of this, but we can't use it for backwards compatibility + */ +class NewCombineTextInputFormat extends NewCombineFileInputFormat[LongWritable,Text] { + def createRecordReader(split: NewInputSplit, context: TaskAttemptContext) + : NewRecordReader[LongWritable, Text] = { + new NewCombineFileRecordReader[LongWritable,Text](split.asInstanceOf[NewCombineFileSplit], + context, classOf[NewCombineTextRecordReaderWrapper]) + } } + +class NewCombineTextRecordReaderWrapper( + split: NewCombineFileSplit, + context: TaskAttemptContext, + idx: Integer) extends NewRecordReader[LongWritable, Text] { + + val fileSplit = new NewFileSplit(split.getPath(idx), + split.getOffset(idx), + split.getLength(idx), + 
split.getLocations()) + + val delegate = new NewTextInputFormat().createRecordReader(fileSplit, context) + + override def initialize(split: NewInputSplit, context: TaskAttemptContext): Unit = { + delegate.initialize(fileSplit, context) + } + + override def nextKeyValue(): Boolean = delegate.nextKeyValue() + override def getCurrentKey(): LongWritable = delegate.getCurrentKey + override def getCurrentValue(): Text = delegate.getCurrentValue + override def getProgress(): Float = delegate.getProgress + override def close(): Unit = delegate.close() +} \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 381ee2d45630f..bede1ffb3e2d0 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -157,6 +157,24 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) } + test("treeAggregate") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + def seqOp = (c: Long, x: Int) => c + x + def combOp = (c1: Long, c2: Long) => c1 + c2 + for (depth <- 1 until 10) { + val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) + assert(sum === -1000L) + } + } + + test("treeReduce") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + for (depth <- 1 until 10) { + val sum = rdd.treeReduce(_ + _, depth) + assert(sum === -1000) + } + } + test("basic caching") { val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() assert(rdd.collect().toList === List(1, 2, 3, 4)) @@ -927,4 +945,45 @@ class RDDSuite extends FunSuite with SharedSparkContext { mutableDependencies += dep } } + + test("nested RDDs are not supported (SPARK-5063)") { + val rdd: RDD[Int] = sc.parallelize(1 to 100) + val rdd2: RDD[Int] = sc.parallelize(1 to 100) + val thrown = intercept[SparkException] { + val nestedRDD: RDD[RDD[Int]] = rdd.mapPartitions { x => Seq(rdd2.map(x => x)).iterator } + 
nestedRDD.count() + } + assert(thrown.getMessage.contains("SPARK-5063")) + } + + test("actions cannot be performed inside of transformations (SPARK-5063)") { + val rdd: RDD[Int] = sc.parallelize(1 to 100) + val rdd2: RDD[Int] = sc.parallelize(1 to 100) + val thrown = intercept[SparkException] { + rdd.map(x => x * rdd2.count).collect() + } + assert(thrown.getMessage.contains("SPARK-5063")) + } + + test("cannot run actions after SparkContext has been stopped (SPARK-5063)") { + val existingRDD = sc.parallelize(1 to 100) + sc.stop() + val thrown = intercept[IllegalStateException] { + existingRDD.count() + } + assert(thrown.getMessage.contains("shutdown")) + } + + test("cannot call methods on a stopped SparkContext (SPARK-5063)") { + sc.stop() + def assertFails(block: => Any): Unit = { + val thrown = intercept[IllegalStateException] { + block + } + assert(thrown.getMessage.contains("stopped")) + } + assertFails { sc.parallelize(1 to 100) } + assertFails { sc.textFile("/nonexistent-path") } + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala index 073814c127edc..f2ff98eb72daf 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala @@ -43,7 +43,7 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea conf.set("spark.mesos.executor.home" , "/mesos-home") val listenerBus = EasyMock.createMock(classOf[LiveListenerBus]) - listenerBus.post(SparkListenerExecutorAdded("s1", new ExecutorInfo("host1", 2))) + listenerBus.post(SparkListenerExecutorAdded(EasyMock.anyLong, "s1", new ExecutorInfo("host1", 2))) EasyMock.replay(listenerBus) val sc = EasyMock.createMock(classOf[SparkContext]) @@ -88,7 +88,7 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with 
Ea val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl]) val listenerBus = EasyMock.createMock(classOf[LiveListenerBus]) - listenerBus.post(SparkListenerExecutorAdded("s1", new ExecutorInfo("host1", 2))) + listenerBus.post(SparkListenerExecutorAdded(EasyMock.anyLong, "s1", new ExecutorInfo("host1", 2))) EasyMock.replay(listenerBus) val sc = EasyMock.createMock(classOf[SparkContext]) diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index dae7bf0e336de..8cf951adb354b 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -49,7 +49,7 @@ class LocalDirsSuite extends FunSuite { } override def clone: SparkConf = { - new MySparkConf().setAll(settings) + new MySparkConf().setAll(getAll) } } // spark.local.dir only contains invalid directories, but that's not a problem since diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index 10541f878476c..1026cb2aa7cae 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -41,7 +41,7 @@ class EventLoopSuite extends FunSuite with Timeouts { } eventLoop.start() (1 to 100).foreach(eventLoop.post) - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert((1 to 100) === buffer.toSeq) } eventLoop.stop() @@ -76,7 +76,7 @@ class EventLoopSuite extends FunSuite with Timeouts { } eventLoop.start() eventLoop.post(1) - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert(e === receivedError) } eventLoop.stop() @@ -98,7 +98,7 @@ class EventLoopSuite extends FunSuite with Timeouts { } eventLoop.start() eventLoop.post(1) - 
eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert(e === receivedError) assert(eventLoop.isActive) } @@ -153,7 +153,7 @@ class EventLoopSuite extends FunSuite with Timeouts { }.start() } - eventually(timeout(5 seconds), interval(200 millis)) { + eventually(timeout(5 seconds), interval(5 millis)) { assert(threadNum * eventsFromEachThread === receivedEventsCount) } eventLoop.stop() @@ -185,4 +185,22 @@ class EventLoopSuite extends FunSuite with Timeouts { } assert(false === eventLoop.isActive) } + + test("EventLoop: stop in eventThread") { + val eventLoop = new EventLoop[Int]("test") { + + override def onReceive(event: Int): Unit = { + stop() + } + + override def onError(e: Throwable): Unit = { + } + + } + eventLoop.start() + eventLoop.post(1) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(!eventLoop.isActive) + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 0357fc6ce2780..6577ebaa2e9a8 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -37,6 +37,9 @@ class JsonProtocolSuite extends FunSuite { val jobSubmissionTime = 1421191042750L val jobCompletionTime = 1421191296660L + val executorAddedTime = 1421458410000L + val executorRemovedTime = 1421458922000L + test("SparkListenerEvent") { val stageSubmitted = SparkListenerStageSubmitted(makeStageInfo(100, 200, 300, 400L, 500L), properties) @@ -73,9 +76,9 @@ class JsonProtocolSuite extends FunSuite { val unpersistRdd = SparkListenerUnpersistRDD(12345) val applicationStart = SparkListenerApplicationStart("The winner of all", None, 42L, "Garfield") val applicationEnd = SparkListenerApplicationEnd(42L) - val executorAdded = SparkListenerExecutorAdded("exec1", + val executorAdded = SparkListenerExecutorAdded(executorAddedTime, 
"exec1", new ExecutorInfo("Hostee.awesome.com", 11)) - val executorRemoved = SparkListenerExecutorRemoved("exec2") + val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason") testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -1453,9 +1456,10 @@ class JsonProtocolSuite extends FunSuite { """ private val executorAddedJsonString = - """ + s""" |{ | "Event": "SparkListenerExecutorAdded", + | "Timestamp": ${executorAddedTime}, | "Executor ID": "exec1", | "Executor Info": { | "Host": "Hostee.awesome.com", @@ -1465,10 +1469,12 @@ class JsonProtocolSuite extends FunSuite { """ private val executorRemovedJsonString = - """ + s""" |{ | "Event": "SparkListenerExecutorRemoved", - | "Executor ID": "exec2" + | "Timestamp": ${executorRemovedTime}, + | "Executor ID": "exec2", + | "Removed Reason": "test reason" |} """ } diff --git a/dev/check-license b/dev/check-license index 72b1013479964..a006f65710d6d 100755 --- a/dev/check-license +++ b/dev/check-license @@ -27,17 +27,17 @@ acquire_rat_jar () { if [[ ! -f "$rat_jar" ]]; then # Download rat launch jar if it hasn't been downloaded yet if [ ! 
-f "$JAR" ]; then - # Download - printf "Attempting to fetch rat\n" - JAR_DL="${JAR}.part" - if hash curl 2>/dev/null; then - curl --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" - elif hash wget 2>/dev/null; then - wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" - else - printf "You do not have curl or wget installed, please install rat manually.\n" - exit -1 - fi + # Download + printf "Attempting to fetch rat\n" + JAR_DL="${JAR}.part" + if [ $(command -v curl) ]; then + curl --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" + elif [ $(command -v wget) ]; then + wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" + else + printf "You do not have curl or wget installed, please install rat manually.\n" + exit -1 + fi fi unzip -tq $JAR &> /dev/null diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index b1b8cb44e098b..b2a7e092a0291 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -122,8 +122,14 @@ if [[ ! "$@" =~ --package-only ]]; then for file in $(find . 
-type f) do echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file; - gpg --print-md MD5 $file > $file.md5; - gpg --print-md SHA1 $file > $file.sha1 + if [ $(command -v md5) ]; then + # Available on OS X; -q to keep only hash + md5 -q $file > $file.md5 + else + # Available on Linux; cut to keep only hash + md5sum $file | cut -f1 -d' ' > $file.md5 + fi + shasum -a 1 $file | cut -f1 -d' ' > $file.sha1 done nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id diff --git a/docs/configuration.md b/docs/configuration.md index efbab4085317a..e4e4b8d516b75 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -197,6 +197,27 @@ Apart from these, the following properties are also available, and may be useful #### Runtime Environment + + + + + + + + + + + + + + + @@ -290,6 +311,9 @@ Apart from these, the following properties are also available, and may be useful or it will be displayed before the driver exiting. It also can be dumped into disk by `sc.dump_profiles(path)`. If some of the profile results had been displayed maually, they will not be displayed automatically before driver exiting. + + By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by + passing a profiler class in as a parameter to the `SparkContext` constructor. diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 2094963392295..ef18cec9371d6 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -192,12 +192,11 @@ We use the default ALS.train() method which assumes ratings are explicit. We eva recommendation by measuring the Mean Squared Error of rating prediction. 
{% highlight python %} -from pyspark.mllib.recommendation import ALS -from numpy import array +from pyspark.mllib.recommendation import ALS, Rating # Load and parse the data data = sc.textFile("data/mllib/als/test.data") -ratings = data.map(lambda line: array([float(x) for x in line.split(',')])) +ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2]))) # Build the recommendation model using Alternating Least Squares rank = 10 @@ -205,10 +204,10 @@ numIterations = 20 model = ALS.train(ratings, rank, numIterations) # Evaluate the model on training data -testdata = ratings.map(lambda p: (int(p[0]), int(p[1]))) +testdata = ratings.map(lambda p: (p[0], p[1])) predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) -MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y)/ratesAndPreds.count() +MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y) / ratesAndPreds.count() print("Mean Squared Error = " + str(MSE)) {% endhighlight %} @@ -217,7 +216,7 @@ signals), you can use the trainImplicit method to get better results. {% highlight python %} # Build the recommendation model using Alternating Least Squares based on implicit ratings -model = ALS.trainImplicit(ratings, rank, numIterations, alpha = 0.01) +model = ALS.trainImplicit(ratings, rank, numIterations, alpha=0.01) {% endhighlight %} diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 2443fc29b4706..6486614e71354 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -886,7 +886,7 @@ for details.
    Property NameDefaultMeaning
    spark.driver.extraJavaOptions(none) + A string of extra JVM options to pass to the driver. For instance, GC settings or other logging. +
    spark.driver.extraClassPath(none) + Extra classpath entries to append to the classpath of the driver. +
    spark.driver.extraLibraryPath(none) + Set a special library path to use when launching the driver JVM. +
    spark.executor.extraJavaOptions (none)
    groupByKey([numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
    Note: If you are grouping in order to perform an aggregation (such as a sum or - average) over each key, using reduceByKey or combineByKey will yield much better + average) over each key, using reduceByKey or aggregateByKey will yield much better performance.
    Note: By default, the level of parallelism in the output depends on the number of partitions of the parent RDD. diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 0e38fe2144e9f..77c0abbbacbd0 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -29,7 +29,7 @@ title: Spark Streaming + Kafka Integration Guide streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]); See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 3bd1deaccfafe..14a87f8436984 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -58,8 +58,8 @@ for applications that involve the REPL (e.g. Spark shell). Alternatively, if your application is submitted from a machine far from the worker machines (e.g. locally on your laptop), it is common to use `cluster` mode to minimize network latency between -the drivers and the executors. Note that `cluster` mode is currently not supported for standalone -clusters, Mesos clusters, or Python applications. +the drivers and the executors. Note that `cluster` mode is currently not supported for +Mesos clusters or Python applications. For Python applications, simply pass a `.py` file in the place of `` instead of a JAR, and add Python `.zip`, `.egg` or `.py` files to the search path with `--py-files`. 
diff --git a/ec2/spark-ec2 b/ec2/spark-ec2 index 3abd3f396f605..26e7d22655694 100755 --- a/ec2/spark-ec2 +++ b/ec2/spark-ec2 @@ -20,6 +20,6 @@ # Preserve the user's CWD so that relative paths are passed correctly to #+ the underlying Python script. -SPARK_EC2_DIR="$(dirname $0)" +SPARK_EC2_DIR="$(dirname "$0")" python -Wdefault "${SPARK_EC2_DIR}/spark_ec2.py" "$@" diff --git a/examples/pom.xml b/examples/pom.xml index 4b92147725f6b..8caad2bc2e27a 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -35,12 +35,6 @@ http://spark.apache.org/ - - - com.google.guava - guava - compile - org.apache.spark spark-core_${scala.binary.version} @@ -310,69 +304,40 @@ org.apache.maven.plugins maven-shade-plugin - - - package - - shade - - - false - ${project.build.directory}/scala-${scala.binary.version}/spark-examples-${project.version}-hadoop${hadoop.version}.jar - - - *:* - - - - - com.google.guava:guava - - - ** - - - - *:* - - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - com.google - org.spark-project.guava - - com.google.common.** - - - com.google.common.base.Optional** - - - - org.apache.commons.math3 - org.spark-project.commons.math3 - - - - - - reference.conf - - - log4j.properties - - - - - + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-examples-${project.version}-hadoop${hadoop.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + org.apache.commons.math3 + org.spark-project.commons.math3 + + + + + + reference.conf + + + log4j.properties + + + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java index 247d2a5e31a8c..0fbee6e433608 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java @@ -33,7 +33,7 @@ import 
org.apache.spark.ml.tuning.CrossValidator; import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; @@ -71,7 +71,7 @@ public static void main(String[] args) { new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), new LabeledDocument(11L, "hadoop software", 0.0)); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -112,11 +112,11 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). 
- cvModel.transform(test).registerAsTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + cvModel.transform(test).registerTempTable("prediction"); + DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); for (Row r: predictions.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + ", prediction=" + r.get(3)); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 5b92655e2e838..eaaa344be49c8 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -28,7 +28,7 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; @@ -48,13 +48,13 @@ public static void main(String[] args) { // Prepare training data. // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans - // into SchemaRDDs, where it uses the bean metadata to infer the schema. + // into DataFrames, where it uses the bean metadata to infer the schema. List localTraining = Lists.newArrayList( new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); + DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. 
This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -94,14 +94,14 @@ public static void main(String[] args) { new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); + DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' // column since we renamed the lr.scoreCol parameter previously. - model2.transform(test).registerAsTable("results"); - SchemaRDD results = + model2.transform(test).registerTempTable("results"); + DataFrame results = jsql.sql("SELECT features, label, probability, prediction FROM results"); for (Row r: results.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index 74db449fada7d..82d665a3e1386 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -29,7 +29,7 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; @@ -54,7 +54,7 @@ public static void main(String[] args) { new LabeledDocument(1L, "b d", 0.0), new 
LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -79,11 +79,11 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. - model.transform(test).registerAsTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + model.transform(test).registerTempTable("prediction"); + DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); for (Row r: predictions.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + ", prediction=" + r.get(3)); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index b70804635d5c9..8defb769ffaaf 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -26,9 +26,9 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; public class JavaSparkSQL { public static class Person implements Serializable { @@ -74,13 +74,13 @@ public Person 
call(String line) { }); // Apply a schema to an RDD of Java Beans and register it as a table. - SchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); + DataFrame schemaPeople = sqlCtx.applySchema(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. - SchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); - // The results of SQL queries are SchemaRDDs and support all the normal RDD operations. + // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. List teenagerNames = teenagers.toJavaRDD().map(new Function() { @Override @@ -93,17 +93,17 @@ public String call(Row row) { } System.out.println("=== Data source: Parquet File ==="); - // JavaSchemaRDDs can be saved as parquet files, maintaining the schema information. + // DataFrames can be saved as parquet files, maintaining the schema information. schemaPeople.saveAsParquetFile("people.parquet"); // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. - // The result of loading a parquet file is also a JavaSchemaRDD. - SchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); + // The result of loading a parquet file is also a DataFrame. + DataFrame parquetFile = sqlCtx.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); - SchemaRDD teenagers2 = + DataFrame teenagers2 = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.toJavaRDD().map(new Function() { @Override @@ -119,8 +119,8 @@ public String call(Row row) { // A JSON dataset is pointed by path. 
// The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; - // Create a JavaSchemaRDD from the file(s) pointed by path - SchemaRDD peopleFromJsonFile = sqlCtx.jsonFile(path); + // Create a DataFrame from the file(s) pointed by path + DataFrame peopleFromJsonFile = sqlCtx.jsonFile(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -130,13 +130,13 @@ public String call(Row row) { // |-- age: IntegerType // |-- name: StringType - // Register this JavaSchemaRDD as a table. + // Register this DataFrame as a table. peopleFromJsonFile.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlCtx. - SchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); - // The results of SQL queries are JavaSchemaRDDs and support all the normal RDD operations. + // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. teenagerNames = teenagers3.toJavaRDD().map(new Function() { @Override @@ -146,14 +146,14 @@ public String call(Row row) { System.out.println(name); } - // Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by + // Alternatively, a DataFrame can be created for a JSON dataset represented by // a RDD[String] storing one JSON object per string. List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - SchemaRDD peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); + DataFrame peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); - // Take a look at the schema of this new JavaSchemaRDD.
+ // Take a look at the schema of this new DataFrame. peopleFromJsonRDD.printSchema(); // The schema of anotherPeople is ... // root @@ -164,7 +164,7 @@ public String call(Row row) { peopleFromJsonRDD.registerTempTable("people2"); - SchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); + DataFrame peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.toJavaRDD().map(new Function() { @Override public String call(Row row) { diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py new file mode 100644 index 0000000000000..c7df3d7b74767 --- /dev/null +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -0,0 +1,79 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +from pyspark.sql import SQLContext, Row +from pyspark.ml import Pipeline +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.ml.classification import LogisticRegression + + +""" +A simple text classification pipeline that recognizes "spark" from +input text. This is to show how to create and configure a Spark ML +pipeline in Python. 
Run with: + + bin/spark-submit examples/src/main/python/ml/simple_text_classification_pipeline.py +""" + + +if __name__ == "__main__": + sc = SparkContext(appName="SimpleTextClassificationPipeline") + sqlCtx = SQLContext(sc) + + # Prepare training documents, which are labeled. + LabeledDocument = Row('id', 'text', 'label') + training = sqlCtx.inferSchema( + sc.parallelize([(0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)]) + .map(lambda x: LabeledDocument(*x))) + + # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + tokenizer = Tokenizer() \ + .setInputCol("text") \ + .setOutputCol("words") + hashingTF = HashingTF() \ + .setInputCol(tokenizer.getOutputCol()) \ + .setOutputCol("features") + lr = LogisticRegression() \ + .setMaxIter(10) \ + .setRegParam(0.01) + pipeline = Pipeline() \ + .setStages([tokenizer, hashingTF, lr]) + + # Fit the pipeline to training documents. + model = pipeline.fit(training) + + # Prepare test documents, which are unlabeled. + Document = Row('id', 'text') + test = sqlCtx.inferSchema( + sc.parallelize([(4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")]) + .map(lambda x: Document(*x))) + + # Make predictions on test documents and print columns of interest. + prediction = model.transform(test) + prediction.registerTempTable("prediction") + selected = sqlCtx.sql("SELECT id, text, prediction from prediction") + for row in selected.collect(): + print row + + sc.stop() diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py index 540dae785f6ea..b5a70db2b9a3c 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/mllib/dataset_example.py @@ -16,7 +16,7 @@ # """ -An example of how to use SchemaRDD as a dataset for ML. Run with:: +An example of how to use DataFrame as a dataset for ML.
Run with:: bin/spark-submit examples/src/main/python/mllib/dataset_example.py """ diff --git a/examples/src/main/python/mllib/gradient_boosted_trees.py b/examples/src/main/python/mllib/gradient_boosted_trees.py new file mode 100644 index 0000000000000..e647773ad9060 --- /dev/null +++ b/examples/src/main/python/mllib/gradient_boosted_trees.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Gradient boosted Trees classification and regression using MLlib. +""" + +import sys + +from pyspark.context import SparkContext +from pyspark.mllib.tree import GradientBoostedTrees +from pyspark.mllib.util import MLUtils + + +def testClassification(trainingData, testData): + # Train a GradientBoostedTrees model. + # Empty categoricalFeaturesInfo indicates all features are continuous. 
+ model = GradientBoostedTrees.trainClassifier(trainingData, categoricalFeaturesInfo={}, + numIterations=30, maxDepth=4) + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() \ + / float(testData.count()) + print('Test Error = ' + str(testErr)) + print('Learned classification ensemble model:') + print(model.toDebugString()) + + +def testRegression(trainingData, testData): + # Train a GradientBoostedTrees model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + model = GradientBoostedTrees.trainRegressor(trainingData, categoricalFeaturesInfo={}, + numIterations=30, maxDepth=4) + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() \ + / float(testData.count()) + print('Test Mean Squared Error = ' + str(testMSE)) + print('Learned regression ensemble model:') + print(model.toDebugString()) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print >> sys.stderr, "Usage: gradient_boosted_trees" + exit(1) + sc = SparkContext(appName="PythonGradientBoostedTrees") + + # Load and parse the data file into an RDD of LabeledPoint. 
+ data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + print('\nRunning example of classification using GradientBoostedTrees\n') + testClassification(trainingData, testData) + + print('\nRunning example of regression using GradientBoostedTrees\n') + testRegression(trainingData, testData) + + sc.stop() diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index d2c5ca48c6cb8..7f5c68e3d0fe2 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -30,18 +30,18 @@ some_rdd = sc.parallelize([Row(name="John", age=19), Row(name="Smith", age=23), Row(name="Sarah", age=18)]) - # Infer schema from the first row, create a SchemaRDD and print the schema - some_schemardd = sqlContext.inferSchema(some_rdd) - some_schemardd.printSchema() + # Infer schema from the first row, create a DataFrame and print the schema + some_df = sqlContext.inferSchema(some_rdd) + some_df.printSchema() # Another RDD is created from a list of tuples another_rdd = sc.parallelize([("John", 19), ("Smith", 23), ("Sarah", 18)]) # Schema with two fields - person_name and person_age schema = StructType([StructField("person_name", StringType(), False), StructField("person_age", IntegerType(), False)]) - # Create a SchemaRDD by applying the schema to the RDD and print the schema - another_schemardd = sqlContext.applySchema(another_rdd, schema) - another_schemardd.printSchema() + # Create a DataFrame by applying the schema to the RDD and print the schema + another_df = sqlContext.applySchema(another_rdd, schema) + another_df.printSchema() # root # |-- age: integer (nullable = true) # |-- name: string (nullable = true) @@ -49,7 +49,7 @@ # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. 
path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") - # Create a SchemaRDD from the file(s) pointed to by path + # Create a DataFrame from the file(s) pointed to by path people = sqlContext.jsonFile(path) # root # |-- person_name: string (nullable = false) @@ -61,7 +61,7 @@ # |-- age: IntegerType # |-- name: StringType - # Register this SchemaRDD as a table. + # Register this DataFrame as a table. people.registerAsTable("people") # SQL statements can be run by using the sql methods provided by sqlContext diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index d8c7ef38ee46d..283bb80f1c788 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -18,7 +18,6 @@ package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator @@ -101,7 +100,7 @@ object CrossValidatorExample { // Make predictions on test documents. cvModel uses the best model found (lrModel). 
cvModel.transform(test) - .select('id, 'text, 'score, 'prediction) + .select("id", "text", "score", "prediction") .collect() .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index cf62772b92651..b7885829459a3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -143,7 +143,7 @@ object MovieLensALS { // Evaluate the model. // TODO: Create an evaluator to compute RMSE. - val mse = predictions.select('rating, 'prediction) + val mse = predictions.select("rating", "prediction").rdd .flatMap { case Row(rating: Float, prediction: Float) => val err = rating.toDouble - prediction val err2 = err * err diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index e8a2adff929cb..95cc9801eaeb9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -18,7 +18,6 @@ package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.param.ParamMap import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -42,7 +41,7 @@ object SimpleParamsExample { // Prepare training data. // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans - // into SchemaRDDs, where it uses the bean metadata to infer the schema. + // into DataFrames, where it uses the bean metadata to infer the schema. 
val training = sparkContext.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), @@ -92,7 +91,7 @@ object SimpleParamsExample { // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' // column since we renamed the lr.scoreCol parameter previously. model2.transform(test) - .select('features, 'label, 'probability, 'prediction) + .select("features", "label", "probability", "prediction") .collect() .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) => println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index b9a6ef0229def..065db62b0f5ed 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -20,7 +20,6 @@ package org.apache.spark.examples.ml import scala.beans.BeanInfo import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} @@ -80,7 +79,7 @@ object SimpleTextClassificationPipeline { // Make predictions on test documents. 
model.transform(test) - .select('id, 'text, 'score, 'prediction) + .select("id", "text", "score", "prediction") .collect() .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index f8d83f4ec7327..ab58375649d25 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -28,10 +28,10 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} +import org.apache.spark.sql.{Row, SQLContext, DataFrame} /** - * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with + * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with * {{{ * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] * }}} @@ -47,7 +47,7 @@ object DatasetExample { val defaultParams = Params() val parser = new OptionParser[Params]("DatasetExample") { - head("Dataset: an example app using SchemaRDD as a Dataset for ML.") + head("Dataset: an example app using DataFrame as a Dataset for ML.") opt[String]("input") .text(s"input path to dataset") .action((x, c) => c.copy(input = x)) @@ -80,20 +80,20 @@ object DatasetExample { } println(s"Loaded ${origData.count()} instances from file: ${params.input}") - // Convert input data to SchemaRDD explicitly. 
- val schemaRDD: SchemaRDD = origData - println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}") - println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") + // Convert input data to DataFrame explicitly. + val df: DataFrame = origData.toDataFrame + println(s"Inferred schema:\n${df.schema.prettyJson}") + println(s"Converted to DataFrame with ${df.count()} records") - // Select columns, using implicit conversion to SchemaRDD. - val labelsSchemaRDD: SchemaRDD = origData.select('label) - val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v } + // Select columns, using implicit conversion to DataFrames. + val labelsDf: DataFrame = origData.select("label") + val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } val numLabels = labels.count() val meanLabel = labels.fold(0.0)(_ + _) / numLabels println(s"Selected label column with average value $meanLabel") - val featuresSchemaRDD: SchemaRDD = origData.select('features) - val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v } + val featuresDf: DataFrame = origData.select("features") + val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) @@ -103,13 +103,13 @@ object DatasetExample { tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") - schemaRDD.saveAsParquetFile(outputDir) + df.saveAsParquetFile(outputDir) println(s"Loading Parquet file with UDT from $outputDir.") val newDataset = sqlContext.parquetFile(outputDir) println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") - val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } + val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } val newFeaturesSummary = newFeatures.aggregate(new 
MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 2e98b2dc30b80..82a0b637b3cff 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -19,6 +19,7 @@ package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.Dsl._ // One method for defining the schema of an RDD is to make a case class with the desired column // names and types. @@ -54,7 +55,7 @@ object RDDRelation { rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) // Queries can also be written using a LINQ-like Scala DSL. - rdd.where('key === 1).orderBy('value.asc).select('key).collect().foreach(println) + rdd.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) // Write out an RDD as a parquet file. rdd.saveAsParquetFile("pair.parquet") @@ -63,7 +64,7 @@ object RDDRelation { val parquetFile = sqlContext.parquetFile("pair.parquet") // Queries can be run using the DSL on parequet files just like the original RDD. - parquetFile.where('key === 1).select('value as 'a).collect().foreach(println) + parquetFile.where($"key" === 1).select($"value".as("a")).collect().foreach(println) // These files can also be registered as tables. 
parquetFile.registerTempTable("parquetFile") diff --git a/graphx/pom.xml b/graphx/pom.xml index 72374aae6da9b..8fac24b6ed86d 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -40,6 +40,10 @@ spark-core_${scala.binary.version} ${project.version} + + com.google.guava + guava + org.jblas jblas diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala index 897c7ee12a436..f1550ac2e18ad 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -19,7 +19,7 @@ package org.apache.spark.graphx.impl import scala.reflect.{classTag, ClassTag} -import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext} +import org.apache.spark.{OneToOneDependency, HashPartitioner, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -46,7 +46,7 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( * partitioner that allows co-partitioning with `partitionsRDD`. 
*/ override val partitioner = - partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD))) + partitionsRDD.partitioner.orElse(Some(new HashPartitioner(partitions.size))) override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 9da0064104fb6..ed9876b8dc21c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -386,4 +386,24 @@ class GraphSuite extends FunSuite with LocalSparkContext { } } + test("non-default number of edge partitions") { + val n = 10 + val defaultParallelism = 3 + val numEdgePartitions = 4 + assert(defaultParallelism != numEdgePartitions) + val conf = new org.apache.spark.SparkConf() + .set("spark.default.parallelism", defaultParallelism.toString) + val sc = new SparkContext("local", "test", conf) + try { + val edges = sc.parallelize((1 to n).map(x => (x: VertexId, 0: VertexId)), + numEdgePartitions) + val graph = Graph.fromEdgeTuples(edges, 1) + val neighborAttrSums = graph.mapReduceTriplets[Int]( + et => Iterator((et.dstId, et.srcAttr)), _ + _) + assert(neighborAttrSums.collect.toSet === Set((0: VertexId, n))) + } finally { + sc.stop() + } + } + } diff --git a/make-distribution.sh b/make-distribution.sh index 4e2f400be3053..051c87c0894ae 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -32,6 +32,10 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false +TACHYON_VERSION="0.5.0" +TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" +TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" + MAKE_TGZ=false NAME=none MVN="$SPARK_HOME/build/mvn" @@ -93,7 +97,7 @@ done if [ -z "$JAVA_HOME" ]; then # Fall back on JAVA_HOME from rpm, if found - if which rpm &>/dev/null; then + if [ 
$(command -v rpm) ]; then RPM_JAVA_HOME=$(rpm -E %java_home 2>/dev/null) if [ "$RPM_JAVA_HOME" != "%java_home" ]; then JAVA_HOME=$RPM_JAVA_HOME @@ -107,7 +111,7 @@ if [ -z "$JAVA_HOME" ]; then exit -1 fi -if which git &>/dev/null; then +if [ $(command -v git) ]; then GITREV=$(git rev-parse --short HEAD 2>/dev/null || :) if [ ! -z $GITREV ]; then GITREVSTRING=" (git revision $GITREV)" @@ -115,14 +119,15 @@ if which git &>/dev/null; then unset GITREV fi -if ! which $MVN &>/dev/null; then + +if [ ! $(command -v $MVN) ] ; then echo -e "Could not locate Maven command: '$MVN'." echo -e "Specify the Maven command with the --mvn flag" exit -1; fi -VERSION=$(mvn help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) -SPARK_HADOOP_VERSION=$(mvn help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ +VERSION=$($MVN help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) +SPARK_HADOOP_VERSION=$($MVN help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ | grep -v "INFO"\ | tail -n 1) SPARK_HIVE=$($MVN help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ @@ -171,13 +176,16 @@ cd "$SPARK_HOME" export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" -BUILD_COMMAND="$MVN clean package -DskipTests $@" +# Store the command as an array because $MVN variable might have spaces in it. +# Normal quoting tricks don't work. +# See: http://mywiki.wooledge.org/BashFAQ/050 +BUILD_COMMAND=("$MVN" clean package -DskipTests $@) # Actually build the jar echo -e "\nBuilding with..." 
-echo -e "\$ $BUILD_COMMAND\n" +echo -e "\$ ${BUILD_COMMAND[@]}\n" -${BUILD_COMMAND} +"${BUILD_COMMAND[@]}" # Make directories rm -rf "$DISTDIR" @@ -222,16 +230,22 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR" # Download and copy in tachyon, if requested if [ "$SPARK_TACHYON" == "true" ]; then - TACHYON_VERSION="0.5.0" - TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/tachyon-${TACHYON_VERSION}-bin.tar.gz" - TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` pushd $TMPD > /dev/null echo "Fetching tachyon tgz" - wget "$TACHYON_URL" - tar xf "tachyon-${TACHYON_VERSION}-bin.tar.gz" + TACHYON_DL="${TACHYON_TGZ}.part" + if [ $(command -v curl) ]; then + curl --silent -k -L "${TACHYON_URL}" > "${TACHYON_DL}" && mv "${TACHYON_DL}" "${TACHYON_TGZ}" + elif [ $(command -v wget) ]; then + wget --quiet "${TACHYON_URL}" -O "${TACHYON_DL}" && mv "${TACHYON_DL}" "${TACHYON_TGZ}" + else + printf "You do not have curl or wget installed. please install Tachyon manually.\n" + exit -1 + fi + + tar xzf "${TACHYON_TGZ}" cp "tachyon-${TACHYON_VERSION}/core/target/tachyon-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" mkdir -p "$DISTDIR/tachyon/src/main/java/tachyon/web" cp -r "tachyon-${TACHYON_VERSION}"/{bin,conf,libexec} "$DISTDIR/tachyon" diff --git a/mllib/pom.xml b/mllib/pom.xml index a0bda89ccaa71..fc2b2cc09c717 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -125,6 +125,9 @@ ../python pyspark/mllib/*.py + pyspark/mllib/stat/*.py + pyspark/ml/*.py + pyspark/ml/param/*.py diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 77d230eb4a122..bc3defe968afd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -21,7 +21,7 @@ import scala.annotation.varargs import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} -import 
org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame /** * :: AlphaComponent :: @@ -38,7 +38,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * @return fitted model */ @varargs - def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = { + def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = { val map = new ParamMap().put(paramPairs: _*) fit(dataset, map) } @@ -50,7 +50,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * @param paramMap parameter map * @return fitted model */ - def fit(dataset: SchemaRDD, paramMap: ParamMap): M + def fit(dataset: DataFrame, paramMap: ParamMap): M /** * Fits multiple models to the input data with multiple sets of parameters. @@ -61,7 +61,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * @param paramMaps an array of parameter maps * @return fitted models, matching the input parameter maps */ - def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = { + def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { paramMaps.map(fit(dataset, _)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala index db563dd550e56..d2ca2e6871e6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame /** * :: AlphaComponent :: @@ -35,5 +35,5 @@ abstract class Evaluator extends Identifiable { * @param paramMap parameter map that specifies the input columns and output metrics * @return metric */ - def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double + def evaluate(dataset: DataFrame, paramMap: ParamMap): Double } diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index ad6fed178fae9..fe39cd1bc0bd2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** @@ -88,7 +88,7 @@ class Pipeline extends Estimator[PipelineModel] { * @param paramMap parameter map * @return fitted pipeline */ - override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = { transformSchema(dataset.schema, paramMap, logging = true) val map = this.paramMap ++ paramMap val theStages = map(stages) @@ -162,7 +162,7 @@ class PipelineModel private[ml] ( } } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap val map = (fittingParamMap ++ this.paramMap) ++ paramMap transformSchema(dataset.schema, map, logging = true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index af56f9c435351..cd95c16aa768d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -22,9 +22,8 @@ import scala.annotation.varargs import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ -import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.catalyst.analysis.Star -import 
org.apache.spark.sql.catalyst.expressions.ScalaUdf +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.types._ /** @@ -41,7 +40,7 @@ abstract class Transformer extends PipelineStage with Params { * @return transformed dataset */ @varargs - def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = { + def transform(dataset: DataFrame, paramPairs: ParamPair[_]*): DataFrame = { val map = new ParamMap() paramPairs.foreach(map.put(_)) transform(dataset, map) @@ -53,7 +52,7 @@ abstract class Transformer extends PipelineStage with Params { * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset */ - def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD + def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame } /** @@ -95,11 +94,10 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O StructType(outputFields) } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr)) - dataset.select(Star(None), udf as map(outputCol)) + dataset.select($"*", callUDF( + this.createTransformFunc(map), outputDataType, dataset(map(inputCol))).as(map(outputCol))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8c570812f8316..18be35ad59452 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -24,8 +24,7 @@ import 
org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Star -import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} import org.apache.spark.storage.StorageLevel @@ -87,11 +86,10 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti def setScoreCol(value: String): this.type = set(scoreCol, value) def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr) + val instances = dataset.select(map(labelCol), map(featuresCol)) .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) }.persist(StorageLevel.MEMORY_AND_DISK) @@ -131,19 +129,17 @@ class LogisticRegressionModel private[ml] ( validateAndTransformSchema(schema, paramMap, fitting = false) } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val score: Vector => Double = (v) => { + val scoreFunction: Vector => Double = (v) => { val margin = BLAS.dot(v, weights) 1.0 / (1.0 + math.exp(-margin)) } val t = map(threshold) - val predict: Double => Double = (score) => { - if (score > t) 1.0 else 0.0 - } - dataset.select(Star(None), 
score.call(map(featuresCol).attr) as map(scoreCol)) - .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol)) + val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 } + dataset + .select($"*", callUDF(scoreFunction, col(map(featuresCol))).as(map(scoreCol))) + .select($"*", callUDF(predictFunction, col(map(scoreCol))).as(map(predictionCol))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 12473cb2b5719..1979ab9eb6516 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.sql.{Row, SchemaRDD} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType /** @@ -41,7 +41,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params def setScoreCol(value: String): this.type = set(scoreCol, value) def setLabelCol(value: String): this.type = set(labelCol, value) - override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = { + override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { val map = this.paramMap ++ paramMap val schema = dataset.schema @@ -52,8 +52,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params require(labelType == DoubleType, s"Label column ${map(labelCol)} must be double type but found $labelType") - import dataset.sqlContext._ - val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr) + val scoreAndLabels = dataset.select(map(scoreCol), map(labelCol)) .map { case Row(score: Double, label: 
Double) => (score, label) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 72825f6e02182..01a4f5eb205e5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -23,8 +23,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Star -import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.types.{StructField, StructType} /** @@ -43,14 +42,10 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP def setInputCol(value: String): this.type = set(inputCol, value) def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val input = dataset.select(map(inputCol).attr) - .map { case Row(v: Vector) => - v - } + val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler().fit(input) val model = new StandardScalerModel(this, map, scaler) Params.inheritValues(map, this, model) @@ -83,14 +78,13 @@ class StandardScalerModel private[ml] ( def setInputCol(value: String): this.type = set(inputCol, value) def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - 
import dataset.sqlContext._ val map = this.paramMap ++ paramMap val scale: (Vector) => Vector = (v) => { scaler.transform(v) } - dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol)) + dataset.select($"*", callUDF(scale, col(map(inputCol))).as(map(outputCol))) } private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 04f9cfb1bfc2f..5fb4379e23c2f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -164,6 +164,13 @@ trait Params extends Identifiable with Serializable { this } + /** + * Sets a parameter (by name) in the embedded param map. + */ + private[ml] def set(param: String, value: Any): this.type = { + set(getParam(param), value) + } + /** * Gets the value of a parameter in the embedded param map. */ @@ -286,7 +293,6 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten new ParamMap(this.map ++ other.map) } - /** * Adds all parameters from the input param map into this param map. 
*/ diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 2d89e76a4c8b2..aaad548143c4b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -29,10 +29,8 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.catalyst.dsl._ -import org.apache.spark.sql.catalyst.expressions.Cast -import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} import org.apache.spark.util.Utils import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} @@ -112,21 +110,11 @@ class ALSModel private[ml] ( def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { - import dataset.sqlContext._ - import org.apache.spark.ml.recommendation.ALSModel.Factor + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + import dataset.sqlContext.createDataFrame val map = this.paramMap ++ paramMap - // TODO: Add DSL to simplify the code here. 
- val instanceTable = s"instance_$uid" - val userTable = s"user_$uid" - val itemTable = s"item_$uid" - val instances = dataset.as(Symbol(instanceTable)) - val users = userFactors.map { case (id, features) => - Factor(id, features) - }.as(Symbol(userTable)) - val items = itemFactors.map { case (id, features) => - Factor(id, features) - }.as(Symbol(itemTable)) + val users = userFactors.toDataFrame("id", "features") + val items = itemFactors.toDataFrame("id", "features") val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => { if (userFeatures != null && itemFeatures != null) { blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) @@ -135,13 +123,14 @@ class ALSModel private[ml] ( } } val inputColumns = dataset.schema.fieldNames - val prediction = - predict.call(s"$userTable.features".attr, s"$itemTable.features".attr) as map(predictionCol) - val outputColumns = inputColumns.map(f => s"$instanceTable.$f".attr as f) :+ prediction - instances - .join(users, LeftOuter, Some(map(userCol).attr === s"$userTable.id".attr)) - .join(items, LeftOuter, Some(map(itemCol).attr === s"$itemTable.id".attr)) + val prediction = callUDF(predict, users("features"), items("features")).as(map(predictionCol)) + val outputColumns = inputColumns.map(f => dataset(f)) :+ prediction + dataset + .join(users, dataset(map(userCol)) === users("id"), "left") + .join(items, dataset(map(itemCol)) === items("id"), "left") .select(outputColumns: _*) + // TODO: Just use a dataset("*") + // .select(dataset("*"), prediction) } override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { @@ -149,10 +138,6 @@ class ALSModel private[ml] ( } } -private object ALSModel { - /** Case class to convert factors to SchemaRDDs */ - private case class Factor(id: Int, features: Seq[Float]) -} /** * Alternating Least Squares (ALS) matrix factorization. 
@@ -209,14 +194,13 @@ class ALS extends Estimator[ALSModel] with ALSParams { setMaxIter(20) setRegParam(1.0) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): ALSModel = { - import dataset.sqlContext._ + override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { val map = this.paramMap ++ paramMap - val ratings = - dataset.select(map(userCol).attr, map(itemCol).attr, Cast(map(ratingCol).attr, FloatType)) - .map { row => - new Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) - } + val ratings = dataset + .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType)) + .map { row => + new Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) + } val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank), numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks), maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs), diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 08fe99176424a..5d51c51346665 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** @@ -64,7 +64,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP def setEvaluator(value: Evaluator): this.type = set(evaluator, value) def setNumFolds(value: Int): this.type = set(numFolds, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): 
CrossValidatorModel = { val map = this.paramMap ++ paramMap val schema = dataset.schema transformSchema(dataset.schema, paramMap, logging = true) @@ -74,7 +74,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP val epm = map(estimatorParamMaps) val numModels = epm.size val metrics = new Array[Double](epm.size) - val splits = MLUtils.kFold(dataset, map(numFolds), 0) + val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.applySchema(training, schema).cache() val validationDataset = sqlCtx.applySchema(validation, schema).cache() @@ -117,7 +117,7 @@ class CrossValidatorModel private[ml] ( val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams { - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { bestModel.transform(dataset, paramMap) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 430d763ef7ca7..a66d6f0cf29c7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -41,10 +41,11 @@ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.test.ChiSqTestResult -import org.apache.spark.mllib.tree.{RandomForest, DecisionTree} -import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} +import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree} +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy} import 
org.apache.spark.mllib.tree.impurity._ -import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel} +import org.apache.spark.mllib.tree.loss.Losses +import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel, RandomForestModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -532,6 +533,35 @@ class PythonMLLibAPI extends Serializable { } } + /** + * Java stub for Python mllib GradientBoostedTrees.train(). + * This stub returns a handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on exit; + * see the Py4J documentation. + */ + def trainGradientBoostedTreesModel( + data: JavaRDD[LabeledPoint], + algoStr: String, + categoricalFeaturesInfo: JMap[Int, Int], + lossStr: String, + numIterations: Int, + learningRate: Double, + maxDepth: Int): GradientBoostedTreesModel = { + val boostingStrategy = BoostingStrategy.defaultParams(algoStr) + boostingStrategy.setLoss(Losses.fromString(lossStr)) + boostingStrategy.setNumIterations(numIterations) + boostingStrategy.setLearningRate(learningRate) + boostingStrategy.treeStrategy.setMaxDepth(maxDepth) + boostingStrategy.treeStrategy.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap + + val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK) + try { + GradientBoostedTrees.train(cached, boostingStrategy) + } finally { + cached.unpersist(blocking = false) + } + } + /** * Java stub for mllib Statistics.colStats(X: RDD[Vector]). * TODO figure out return type. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index 3260f27513c7f..a89eea0e21be2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -22,7 +22,6 @@ import breeze.linalg.{DenseVector => BDV} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 3c2091732f9b0..2f2c6f94e9095 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -20,7 +20,6 @@ package org.apache.spark.mllib.feature import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index d25a7cd5b439d..a3e40200bc063 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -290,6 +290,13 @@ class Word2Vec extends Serializable with Logging { val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) + + if (vocabSize.toLong * vectorSize * 8 >= Int.MaxValue) { + throw new RuntimeException("Please increase minCount or 
decrease vectorSize in Word2Vec" + + " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " + + "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue/8`.") + } + val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 3414daccd7ca4..34e0392f1b21a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -257,80 +257,58 @@ private[spark] object BLAS extends Serializable with Logging { /** * C := alpha * A * B + beta * C - * @param transA whether to use the transpose of matrix A (true), or A itself (false). - * @param transB whether to use the transpose of matrix B (true), or B itself (false). * @param alpha a scalar to scale the multiplication A * B. * @param A the matrix A that will be left multiplied to B. Size of m x k. * @param B the matrix B that will be left multiplied by A. Size of k x n. * @param beta a scalar that can be used to scale matrix C. - * @param C the resulting matrix C. Size of m x n. + * @param C the resulting matrix C. Size of m x n. C.isTransposed must be false. */ def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: Matrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { + require(!C.isTransposed, + "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.") if (alpha == 0.0) { logDebug("gemm: alpha is equal to 0. 
Returning C.") } else { A match { case sparse: SparseMatrix => - gemm(transA, transB, alpha, sparse, B, beta, C) + gemm(alpha, sparse, B, beta, C) case dense: DenseMatrix => - gemm(transA, transB, alpha, dense, B, beta, C) + gemm(alpha, dense, B, beta, C) case _ => throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.") } } } - /** - * C := alpha * A * B + beta * C - * - * @param alpha a scalar to scale the multiplication A * B. - * @param A the matrix A that will be left multiplied to B. Size of m x k. - * @param B the matrix B that will be left multiplied by A. Size of k x n. - * @param beta a scalar that can be used to scale matrix C. - * @param C the resulting matrix C. Size of m x n. - */ - def gemm( - alpha: Double, - A: Matrix, - B: DenseMatrix, - beta: Double, - C: DenseMatrix): Unit = { - gemm(false, false, alpha, A, B, beta, C) - } - /** * C := alpha * A * B + beta * C * For `DenseMatrix` A. */ private def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: DenseMatrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { - val mA: Int = if (!transA) A.numRows else A.numCols - val nB: Int = if (!transB) B.numCols else B.numRows - val kA: Int = if (!transA) A.numCols else A.numRows - val kB: Int = if (!transB) B.numRows else B.numCols - val tAstr = if (!transA) "N" else "T" - val tBstr = if (!transB) "N" else "T" - - require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") - require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") - require(nB == C.numCols, - s"The columns of C don't match the columns of B. 
C: ${C.numCols}, A: $nB") - - nativeBLAS.dgemm(tAstr, tBstr, mA, nB, kA, alpha, A.values, A.numRows, B.values, B.numRows, - beta, C.values, C.numRows) + val tAstr = if (A.isTransposed) "T" else "N" + val tBstr = if (B.isTransposed) "T" else "N" + val lda = if (!A.isTransposed) A.numRows else A.numCols + val ldb = if (!B.isTransposed) B.numRows else B.numCols + + require(A.numCols == B.numRows, + s"The columns of A don't match the rows of B. A: ${A.numCols}, B: ${B.numRows}") + require(A.numRows == C.numRows, + s"The rows of C don't match the rows of A. C: ${C.numRows}, A: ${A.numRows}") + require(B.numCols == C.numCols, + s"The columns of C don't match the columns of B. C: ${C.numCols}, A: ${B.numCols}") + + nativeBLAS.dgemm(tAstr, tBstr, A.numRows, B.numCols, A.numCols, alpha, A.values, lda, + B.values, ldb, beta, C.values, C.numRows) } /** @@ -338,17 +316,15 @@ private[spark] object BLAS extends Serializable with Logging { * For `SparseMatrix` A. */ private def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: SparseMatrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { - val mA: Int = if (!transA) A.numRows else A.numCols - val nB: Int = if (!transB) B.numCols else B.numRows - val kA: Int = if (!transA) A.numCols else A.numRows - val kB: Int = if (!transB) B.numRows else B.numCols + val mA: Int = A.numRows + val nB: Int = B.numCols + val kA: Int = A.numCols + val kB: Int = B.numRows require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") @@ -358,23 +334,23 @@ private[spark] object BLAS extends Serializable with Logging { val Avals = A.values val Bvals = B.values val Cvals = C.values - val Arows = if (!transA) A.rowIndices else A.colPtrs - val Acols = if (!transA) A.colPtrs else A.rowIndices + val ArowIndices = A.rowIndices + val AcolPtrs = A.colPtrs // Slicing is easy in this case. 
This is the optimal multiplication setting for sparse matrices - if (transA){ + if (A.isTransposed){ var colCounterForB = 0 - if (!transB) { // Expensive to put the check inside the loop + if (!B.isTransposed) { // Expensive to put the check inside the loop while (colCounterForB < nB) { var rowCounterForA = 0 val Cstart = colCounterForB * mA val Bstart = colCounterForB * kA while (rowCounterForA < mA) { - var i = Arows(rowCounterForA) - val indEnd = Arows(rowCounterForA + 1) + var i = AcolPtrs(rowCounterForA) + val indEnd = AcolPtrs(rowCounterForA + 1) var sum = 0.0 while (i < indEnd) { - sum += Avals(i) * Bvals(Bstart + Acols(i)) + sum += Avals(i) * Bvals(Bstart + ArowIndices(i)) i += 1 } val Cindex = Cstart + rowCounterForA @@ -385,19 +361,19 @@ private[spark] object BLAS extends Serializable with Logging { } } else { while (colCounterForB < nB) { - var rowCounter = 0 + var rowCounterForA = 0 val Cstart = colCounterForB * mA - while (rowCounter < mA) { - var i = Arows(rowCounter) - val indEnd = Arows(rowCounter + 1) + while (rowCounterForA < mA) { + var i = AcolPtrs(rowCounterForA) + val indEnd = AcolPtrs(rowCounterForA + 1) var sum = 0.0 while (i < indEnd) { - sum += Avals(i) * B(colCounterForB, Acols(i)) + sum += Avals(i) * B(ArowIndices(i), colCounterForB) i += 1 } - val Cindex = Cstart + rowCounter + val Cindex = Cstart + rowCounterForA Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha - rowCounter += 1 + rowCounterForA += 1 } colCounterForB += 1 } @@ -410,17 +386,17 @@ private[spark] object BLAS extends Serializable with Logging { // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of // B, and added to C. 
var colCounterForB = 0 // the column to be updated in C - if (!transB) { // Expensive to put the check inside the loop + if (!B.isTransposed) { // Expensive to put the check inside the loop while (colCounterForB < nB) { var colCounterForA = 0 // The column of A to multiply with the row of B val Bstart = colCounterForB * kB val Cstart = colCounterForB * mA while (colCounterForA < kA) { - var i = Acols(colCounterForA) - val indEnd = Acols(colCounterForA + 1) + var i = AcolPtrs(colCounterForA) + val indEnd = AcolPtrs(colCounterForA + 1) val Bval = Bvals(Bstart + colCounterForA) * alpha while (i < indEnd) { - Cvals(Cstart + Arows(i)) += Avals(i) * Bval + Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval i += 1 } colCounterForA += 1 @@ -432,11 +408,11 @@ private[spark] object BLAS extends Serializable with Logging { var colCounterForA = 0 // The column of A to multiply with the row of B val Cstart = colCounterForB * mA while (colCounterForA < kA) { - var i = Acols(colCounterForA) - val indEnd = Acols(colCounterForA + 1) - val Bval = B(colCounterForB, colCounterForA) * alpha + var i = AcolPtrs(colCounterForA) + val indEnd = AcolPtrs(colCounterForA + 1) + val Bval = B(colCounterForA, colCounterForB) * alpha while (i < indEnd) { - Cvals(Cstart + Arows(i)) += Avals(i) * Bval + Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval i += 1 } colCounterForA += 1 @@ -449,7 +425,6 @@ private[spark] object BLAS extends Serializable with Logging { /** * y := alpha * A * x + beta * y - * @param trans whether to use the transpose of matrix A (true), or A itself (false). * @param alpha a scalar to scale the multiplication A * x. * @param A the matrix A that will be left multiplied to x. Size of m x n. * @param x the vector x that will be left multiplied by A. Size of n x 1. @@ -457,65 +432,43 @@ private[spark] object BLAS extends Serializable with Logging { * @param y the resulting vector y. Size of m x 1. 
*/ def gemv( - trans: Boolean, alpha: Double, A: Matrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - - val mA: Int = if (!trans) A.numRows else A.numCols - val nx: Int = x.size - val nA: Int = if (!trans) A.numCols else A.numRows - - require(nA == nx, s"The columns of A don't match the number of elements of x. A: $nA, x: $nx") - require(mA == y.size, - s"The rows of A don't match the number of elements of y. A: $mA, y:${y.size}}") + require(A.numCols == x.size, + s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}") + require(A.numRows == y.size, + s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}}") if (alpha == 0.0) { logDebug("gemv: alpha is equal to 0. Returning y.") } else { A match { case sparse: SparseMatrix => - gemv(trans, alpha, sparse, x, beta, y) + gemv(alpha, sparse, x, beta, y) case dense: DenseMatrix => - gemv(trans, alpha, dense, x, beta, y) + gemv(alpha, dense, x, beta, y) case _ => throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.") } } } - /** - * y := alpha * A * x + beta * y - * - * @param alpha a scalar to scale the multiplication A * x. - * @param A the matrix A that will be left multiplied to x. Size of m x n. - * @param x the vector x that will be left multiplied by A. Size of n x 1. - * @param beta a scalar that can be used to scale vector y. - * @param y the resulting vector y. Size of m x 1. - */ - def gemv( - alpha: Double, - A: Matrix, - x: DenseVector, - beta: Double, - y: DenseVector): Unit = { - gemv(false, alpha, A, x, beta, y) - } - /** * y := alpha * A * x + beta * y * For `DenseMatrix` A. 
*/ private def gemv( - trans: Boolean, alpha: Double, A: DenseMatrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - val tStrA = if (!trans) "N" else "T" - nativeBLAS.dgemv(tStrA, A.numRows, A.numCols, alpha, A.values, A.numRows, x.values, 1, beta, + val tStrA = if (A.isTransposed) "T" else "N" + val mA = if (!A.isTransposed) A.numRows else A.numCols + val nA = if (!A.isTransposed) A.numCols else A.numRows + nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta, y.values, 1) } @@ -524,24 +477,21 @@ private[spark] object BLAS extends Serializable with Logging { * For `SparseMatrix` A. */ private def gemv( - trans: Boolean, alpha: Double, A: SparseMatrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - val xValues = x.values val yValues = y.values - - val mA: Int = if (!trans) A.numRows else A.numCols - val nA: Int = if (!trans) A.numCols else A.numRows + val mA: Int = A.numRows + val nA: Int = A.numCols val Avals = A.values - val Arows = if (!trans) A.rowIndices else A.colPtrs - val Acols = if (!trans) A.colPtrs else A.rowIndices + val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs + val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices - if (trans) { + if (A.isTransposed) { var rowCounter = 0 while (rowCounter < mA) { var i = Arows(rowCounter) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 5a7281ec6dc3c..ad7e86827b368 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -34,8 +34,17 @@ sealed trait Matrix extends Serializable { /** Number of columns. */ def numCols: Int + /** Flag that keeps track whether the matrix is transposed or not. False by default. 
*/ + val isTransposed: Boolean = false + /** Converts to a dense array in column major. */ - def toArray: Array[Double] + def toArray: Array[Double] = { + val newArray = new Array[Double](numRows * numCols) + foreachActive { (i, j, v) => + newArray(j * numRows + i) = v + } + newArray + } /** Converts to a breeze matrix. */ private[mllib] def toBreeze: BM[Double] @@ -52,10 +61,13 @@ sealed trait Matrix extends Serializable { /** Get a deep copy of the matrix. */ def copy: Matrix + /** Transpose the Matrix. Returns a new `Matrix` instance sharing the same underlying data. */ + def transpose: Matrix + /** Convenience method for `Matrix`-`DenseMatrix` multiplication. */ def multiply(y: DenseMatrix): DenseMatrix = { - val C: DenseMatrix = Matrices.zeros(numRows, y.numCols).asInstanceOf[DenseMatrix] - BLAS.gemm(false, false, 1.0, this, y, 0.0, C) + val C: DenseMatrix = DenseMatrix.zeros(numRows, y.numCols) + BLAS.gemm(1.0, this, y, 0.0, C) C } @@ -66,20 +78,6 @@ sealed trait Matrix extends Serializable { output } - /** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */ - private[mllib] def transposeMultiply(y: DenseMatrix): DenseMatrix = { - val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix] - BLAS.gemm(true, false, 1.0, this, y, 0.0, C) - C - } - - /** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */ - private[mllib] def transposeMultiply(y: DenseVector): DenseVector = { - val output = new DenseVector(new Array[Double](numCols)) - BLAS.gemv(true, 1.0, this, y, 0.0, output) - output - } - /** A human readable representation of the matrix */ override def toString: String = toBreeze.toString() @@ -92,6 +90,16 @@ sealed trait Matrix extends Serializable { * backing array. For example, an operation such as addition or subtraction will only be * performed on the non-zero values in a `SparseMatrix`. 
*/ private[mllib] def update(f: Double => Double): Matrix + + /** + * Applies a function `f` to all the active elements of dense and sparse matrix. The ordering + * of the elements are not defined. + * + * @param f the function takes three parameters where the first two parameters are the row + * and column indices respectively with the type `Int`, and the final parameter is the + * corresponding value in the matrix with type `Double`. + */ + private[spark] def foreachActive(f: (Int, Int, Double) => Unit) } /** @@ -108,13 +116,35 @@ sealed trait Matrix extends Serializable { * @param numRows number of rows * @param numCols number of columns * @param values matrix entries in column major + * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in + * row major. */ -class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) extends Matrix { +class DenseMatrix( + val numRows: Int, + val numCols: Int, + val values: Array[Double], + override val isTransposed: Boolean) extends Matrix { require(values.length == numRows * numCols, "The number of values supplied doesn't match the " + s"size of the matrix! values.length: ${values.length}, numRows * numCols: ${numRows * numCols}") - override def toArray: Array[Double] = values + /** + * Column-major dense matrix. + * The entry values are stored in a single array of doubles with columns listed in sequence. + * For example, the following matrix + * {{{ + * 1.0 2.0 + * 3.0 4.0 + * 5.0 6.0 + * }}} + * is stored as `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]`. 
+ * + * @param numRows number of rows + * @param numCols number of columns + * @param values matrix entries in column major + */ + def this(numRows: Int, numCols: Int, values: Array[Double]) = + this(numRows, numCols, values, false) override def equals(o: Any) = o match { case m: DenseMatrix => @@ -122,13 +152,22 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) case _ => false } - private[mllib] def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values) + private[mllib] def toBreeze: BM[Double] = { + if (!isTransposed) { + new BDM[Double](numRows, numCols, values) + } else { + val breezeMatrix = new BDM[Double](numCols, numRows, values) + breezeMatrix.t + } + } private[mllib] def apply(i: Int): Double = values(i) private[mllib] def apply(i: Int, j: Int): Double = values(index(i, j)) - private[mllib] def index(i: Int, j: Int): Int = i + numRows * j + private[mllib] def index(i: Int, j: Int): Int = { + if (!isTransposed) i + numRows * j else j + numCols * i + } private[mllib] def update(i: Int, j: Int, v: Double): Unit = { values(index(i, j)) = v @@ -148,7 +187,38 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) this } - /** Generate a `SparseMatrix` from the given `DenseMatrix`. */ + override def transpose: Matrix = new DenseMatrix(numCols, numRows, values, !isTransposed) + + private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + if (!isTransposed) { + // outer loop over columns + var j = 0 + while (j < numCols) { + var i = 0 + val indStart = j * numRows + while (i < numRows) { + f(i, j, values(indStart + i)) + i += 1 + } + j += 1 + } + } else { + // outer loop over rows + var i = 0 + while (i < numRows) { + var j = 0 + val indStart = i * numCols + while (j < numCols) { + f(i, j, values(indStart + j)) + j += 1 + } + i += 1 + } + } + } + + /** Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed + * set to false. 
*/ def toSparse(): SparseMatrix = { val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble val colPtrs: Array[Int] = new Array[Int](numCols + 1) @@ -157,9 +227,8 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) var j = 0 while (j < numCols) { var i = 0 - val indStart = j * numRows while (i < numRows) { - val v = values(indStart + i) + val v = values(index(i, j)) if (v != 0.0) { rowIndices += i spVals += v @@ -271,49 +340,73 @@ object DenseMatrix { * @param rowIndices the row index of the entry. They must be in strictly increasing order for each * column * @param values non-zero matrix entries in column major + * @param isTransposed whether the matrix is transposed. If true, the matrix can be considered + * Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs, + * and `rowIndices` behave as colIndices, and `values` are stored in row major. */ class SparseMatrix( val numRows: Int, val numCols: Int, val colPtrs: Array[Int], val rowIndices: Array[Int], - val values: Array[Double]) extends Matrix { + val values: Array[Double], + override val isTransposed: Boolean) extends Matrix { require(values.length == rowIndices.length, "The number of row indices and values don't match! " + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") - require(colPtrs.length == numCols + 1, "The length of the column indices should be the " + - s"number of columns + 1. Currently, colPointers.length: ${colPtrs.length}, " + - s"numCols: $numCols") + // The Or statement is for the case when the matrix is transposed + require(colPtrs.length == numCols + 1 || colPtrs.length == numRows + 1, "The length of the " + + "column indices should be the number of columns + 1. Currently, colPointers.length: " + + s"${colPtrs.length}, numCols: $numCols") require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " + s"elements. 
values.length: ${values.length}, colPtrs.last: ${colPtrs.last}") - override def toArray: Array[Double] = { - val arr = new Array[Double](numRows * numCols) - var j = 0 - while (j < numCols) { - var i = colPtrs(j) - val indEnd = colPtrs(j + 1) - val offset = j * numRows - while (i < indEnd) { - val rowIndex = rowIndices(i) - arr(offset + rowIndex) = values(i) - i += 1 - } - j += 1 - } - arr + /** + * Column-major sparse matrix. + * The entry values are stored in Compressed Sparse Column (CSC) format. + * For example, the following matrix + * {{{ + * 1.0 0.0 4.0 + * 0.0 3.0 5.0 + * 2.0 0.0 6.0 + * }}} + * is stored as `values: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]`, + * `rowIndices=[0, 2, 1, 0, 1, 2]`, `colPointers=[0, 2, 3, 6]`. + * + * @param numRows number of rows + * @param numCols number of columns + * @param colPtrs the index corresponding to the start of a new column + * @param rowIndices the row index of the entry. They must be in strictly increasing + * order for each column + * @param values non-zero matrix entries in column major + */ + def this( + numRows: Int, + numCols: Int, + colPtrs: Array[Int], + rowIndices: Array[Int], + values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false) + + private[mllib] def toBreeze: BM[Double] = { + if (!isTransposed) { + new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) + } else { + val breezeMatrix = new BSM[Double](values, numCols, numRows, colPtrs, rowIndices) + breezeMatrix.t + } } - private[mllib] def toBreeze: BM[Double] = - new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) - private[mllib] def apply(i: Int, j: Int): Double = { val ind = index(i, j) if (ind < 0) 0.0 else values(ind) } private[mllib] def index(i: Int, j: Int): Int = { - Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i) + if (!isTransposed) { + Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i) + } else { + Arrays.binarySearch(rowIndices, colPtrs(i), colPtrs(i + 1), j) + } } 
private[mllib] def update(i: Int, j: Int, v: Double): Unit = { @@ -322,7 +415,7 @@ class SparseMatrix( throw new NoSuchElementException("The given row and column indices correspond to a zero " + "value. Only non-zero elements in Sparse Matrices can be updated.") } else { - values(index(i, j)) = v + values(ind) = v } } @@ -341,7 +434,38 @@ class SparseMatrix( this } - /** Generate a `DenseMatrix` from the given `SparseMatrix`. */ + override def transpose: Matrix = + new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed) + + private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + if (!isTransposed) { + var j = 0 + while (j < numCols) { + var idx = colPtrs(j) + val idxEnd = colPtrs(j + 1) + while (idx < idxEnd) { + f(rowIndices(idx), j, values(idx)) + idx += 1 + } + j += 1 + } + } else { + var i = 0 + while (i < numRows) { + var idx = colPtrs(i) + val idxEnd = colPtrs(i + 1) + while (idx < idxEnd) { + val j = rowIndices(idx) + f(i, j, values(idx)) + idx += 1 + } + i += 1 + } + } + } + + /** Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed + * set to false. 
*/ def toDense(): DenseMatrix = { new DenseMatrix(numRows, numCols, toArray) } @@ -557,10 +681,9 @@ object Matrices { private[mllib] def fromBreeze(breeze: BM[Double]): Matrix = { breeze match { case dm: BDM[Double] => - require(dm.majorStride == dm.rows, - "Do not support stride size different from the number of rows.") - new DenseMatrix(dm.rows, dm.cols, dm.data) + new DenseMatrix(dm.rows, dm.cols, dm.data, dm.isTranspose) case sm: BSM[Double] => + // There is no isTranspose flag for sparse matrices in Breeze new SparseMatrix(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data) case _ => throw new UnsupportedOperationException( @@ -679,46 +802,28 @@ object Matrices { new DenseMatrix(numRows, numCols, matrices.flatMap(_.toArray)) } else { var startCol = 0 - val entries: Array[(Int, Int, Double)] = matrices.flatMap { - case spMat: SparseMatrix => - var j = 0 - val colPtrs = spMat.colPtrs - val rowIndices = spMat.rowIndices - val values = spMat.values - val data = new Array[(Int, Int, Double)](values.length) - val nCols = spMat.numCols - while (j < nCols) { - var idx = colPtrs(j) - while (idx < colPtrs(j + 1)) { - val i = rowIndices(idx) - val v = values(idx) - data(idx) = (i, j + startCol, v) - idx += 1 + val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat => + val nCols = mat.numCols + mat match { + case spMat: SparseMatrix => + val data = new Array[(Int, Int, Double)](spMat.values.length) + var cnt = 0 + spMat.foreachActive { (i, j, v) => + data(cnt) = (i, j + startCol, v) + cnt += 1 } - j += 1 - } - startCol += nCols - data - case dnMat: DenseMatrix => - val data = new ArrayBuffer[(Int, Int, Double)]() - var j = 0 - val nCols = dnMat.numCols - val nRows = dnMat.numRows - val values = dnMat.values - while (j < nCols) { - var i = 0 - val indStart = j * nRows - while (i < nRows) { - val v = values(indStart + i) + startCol += nCols + data + case dnMat: DenseMatrix => + val data = new ArrayBuffer[(Int, Int, Double)]() + dnMat.foreachActive { (i, j, v) => 
if (v != 0.0) { data.append((i, j + startCol, v)) } - i += 1 } - j += 1 - } - startCol += nCols - data + startCol += nCols + data + } } SparseMatrix.fromCOO(numRows, numCols, entries) } @@ -744,14 +849,12 @@ object Matrices { require(numCols == mat.numCols, "The number of rows of the matrices in this sequence, " + "don't match!") mat match { - case sparse: SparseMatrix => - hasSparse = true - case dense: DenseMatrix => + case sparse: SparseMatrix => hasSparse = true + case dense: DenseMatrix => // empty on purpose case _ => throw new IllegalArgumentException("Unsupported matrix format. Expected " + s"SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}") } numRows += mat.numRows - } if (!hasSparse) { val allValues = new Array[Double](numRows * numCols) @@ -759,61 +862,37 @@ object Matrices { matrices.foreach { mat => var j = 0 val nRows = mat.numRows - val values = mat.toArray - while (j < numCols) { - var i = 0 + mat.foreachActive { (i, j, v) => val indStart = j * numRows + startRow - val subMatStart = j * nRows - while (i < nRows) { - allValues(indStart + i) = values(subMatStart + i) - i += 1 - } - j += 1 + allValues(indStart + i) = v } startRow += nRows } new DenseMatrix(numRows, numCols, allValues) } else { var startRow = 0 - val entries: Array[(Int, Int, Double)] = matrices.flatMap { - case spMat: SparseMatrix => - var j = 0 - val colPtrs = spMat.colPtrs - val rowIndices = spMat.rowIndices - val values = spMat.values - val data = new Array[(Int, Int, Double)](values.length) - while (j < numCols) { - var idx = colPtrs(j) - while (idx < colPtrs(j + 1)) { - val i = rowIndices(idx) - val v = values(idx) - data(idx) = (i + startRow, j, v) - idx += 1 + val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat => + val nRows = mat.numRows + mat match { + case spMat: SparseMatrix => + val data = new Array[(Int, Int, Double)](spMat.values.length) + var cnt = 0 + spMat.foreachActive { (i, j, v) => + data(cnt) = (i + startRow, j, v) + cnt += 1 } - j += 1 - } - 
startRow += spMat.numRows - data - case dnMat: DenseMatrix => - val data = new ArrayBuffer[(Int, Int, Double)]() - var j = 0 - val nCols = dnMat.numCols - val nRows = dnMat.numRows - val values = dnMat.values - while (j < nCols) { - var i = 0 - val indStart = j * nRows - while (i < nRows) { - val v = values(indStart + i) + startRow += nRows + data + case dnMat: DenseMatrix => + val data = new ArrayBuffer[(Int, Int, Double)]() + dnMat.foreachActive { (i, j, v) => if (v != 0.0) { data.append((i + startRow, j, v)) } - i += 1 } - j += 1 - } - startRow += nRows - data + startRow += nRows + data + } } SparseMatrix.fromCOO(numRows, numCols, entries) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 7ee0224ad4662..8f75e6f46e05d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -27,7 +27,8 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException import org.apache.spark.mllib.util.NumericParser -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types._ /** @@ -77,7 +78,7 @@ sealed trait Vector extends Serializable { result = 31 * result + (bits ^ (bits >>> 32)).toInt } } - return result + result } /** @@ -110,7 +111,7 @@ sealed trait Vector extends Serializable { /** * User-defined type for [[Vector]] which allows easy interaction with SQL - * via [[org.apache.spark.sql.SchemaRDD]]. + * via [[org.apache.spark.sql.DataFrame]]. */ private[spark] class VectorUDT extends UserDefinedType[Vector] { @@ -333,7 +334,7 @@ object Vectors { math.pow(sum, 1.0 / p) } } - + /** * Returns the squared distance between two Vectors. * @param v1 first Vector. 
@@ -341,8 +342,9 @@ object Vectors { * @return squared distance between two Vectors. */ def sqdist(v1: Vector, v2: Vector): Double = { + require(v1.size == v2.size, "vector dimension mismatch") var squaredDistance = 0.0 - (v1, v2) match { + (v1, v2) match { case (v1: SparseVector, v2: SparseVector) => val v1Values = v1.values val v1Indices = v1.indices @@ -350,12 +352,12 @@ object Vectors { val v2Indices = v2.indices val nnzv1 = v1Indices.size val nnzv2 = v2Indices.size - + var kv1 = 0 var kv2 = 0 while (kv1 < nnzv1 || kv2 < nnzv2) { var score = 0.0 - + if (kv2 >= nnzv2 || (kv1 < nnzv1 && v1Indices(kv1) < v2Indices(kv2))) { score = v1Values(kv1) kv1 += 1 @@ -370,18 +372,23 @@ object Vectors { squaredDistance += score * score } - case (v1: SparseVector, v2: DenseVector) if v1.indices.length / v1.size < 0.5 => + case (v1: SparseVector, v2: DenseVector) => squaredDistance = sqdist(v1, v2) - case (v1: DenseVector, v2: SparseVector) if v2.indices.length / v2.size < 0.5 => + case (v1: DenseVector, v2: SparseVector) => squaredDistance = sqdist(v2, v1) - // When a SparseVector is approximately dense, we treat it as a DenseVector - case (v1, v2) => - squaredDistance = v1.toArray.zip(v2.toArray).foldLeft(0.0){ (distance, elems) => - val score = elems._1 - elems._2 - distance + score * score + case (DenseVector(vv1), DenseVector(vv2)) => + var kv = 0 + val sz = vv1.size + while (kv < sz) { + val score = vv1(kv) - vv2(kv) + squaredDistance += score * score + kv += 1 } + case _ => + throw new IllegalArgumentException("Do not support vector type " + v1.getClass + + " and " + v2.getClass) } squaredDistance } @@ -397,7 +404,7 @@ object Vectors { val nnzv1 = indices.size val nnzv2 = v2.size var iv1 = if (nnzv1 > 0) indices(kv1) else -1 - + while (kv2 < nnzv2) { var score = 0.0 if (kv2 != iv1) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala new file mode 
100644 index 0000000000000..693419f827379 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg.distributed + +import scala.collection.mutable.ArrayBuffer + +import breeze.linalg.{DenseMatrix => BDM} + +import org.apache.spark.{Logging, Partitioner} +import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrix} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +/** + * A grid partitioner, which uses a regular grid to partition coordinates. + * + * @param rows Number of rows. + * @param cols Number of columns. + * @param rowsPerPart Number of rows per partition, which may be less at the bottom edge. + * @param colsPerPart Number of columns per partition, which may be less at the right edge. 
+ */
+private[mllib] class GridPartitioner(
+    val rows: Int,
+    val cols: Int,
+    val rowsPerPart: Int,
+    val colsPerPart: Int) extends Partitioner {
+
+  require(rows > 0)
+  require(cols > 0)
+  require(rowsPerPart > 0)
+  require(colsPerPart > 0)
+
+  // Multiply by 1.0 so the division is floating-point. With plain Int division the quotient is
+  // truncated before math.ceil is applied, which drops the partial partition at the bottom/right
+  // edge and lets getPartitionId return an id that is >= numPartitions.
+  private val rowPartitions = math.ceil(rows * 1.0 / rowsPerPart).toInt
+  private val colPartitions = math.ceil(cols * 1.0 / colsPerPart).toInt
+
+  override val numPartitions = rowPartitions * colPartitions
+
+  /**
+   * Returns the index of the partition the input coordinate belongs to.
+   *
+   * @param key The coordinate (i, j) or a tuple (i, j, k), where k is the inner index used in
+   *            multiplication. k is ignored in computing partitions.
+   * @return The index of the partition, which the coordinate belongs to.
+   * @throws IllegalArgumentException if the key is neither an (Int, Int) nor an (Int, Int, Int)
+   *                                  tuple, or if i or j is out of range.
+   */
+  override def getPartition(key: Any): Int = {
+    key match {
+      case (i: Int, j: Int) =>
+        getPartitionId(i, j)
+      case (i: Int, j: Int, _: Int) =>
+        getPartitionId(i, j)
+      case _ =>
+        throw new IllegalArgumentException(s"Unrecognized key: $key.")
+    }
+  }
+
+  /**
+   * Partitions sub-matrices as blocks with neighboring sub-matrices.
+   * Grid cells are numbered column by column: the row-partition index varies fastest.
+   */
+  private def getPartitionId(i: Int, j: Int): Int = {
+    require(0 <= i && i < rows, s"Row index $i out of range [0, $rows).")
+    require(0 <= j && j < cols, s"Column index $j out of range [0, $cols).")
+    i / rowsPerPart + j / colsPerPart * rowPartitions
+  }
+
+  /** Two grid partitioners are equal when all four constructor parameters match. */
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case r: GridPartitioner =>
+        (this.rows == r.rows) && (this.cols == r.cols) &&
+          (this.rowsPerPart == r.rowsPerPart) && (this.colsPerPart == r.colsPerPart)
+      case _ =>
+        false
+    }
+  }
+
+  /** Hashes the same four fields that equals compares, so equal partitioners hash alike. */
+  override def hashCode: Int = {
+    (rows, cols, rowsPerPart, colsPerPart).hashCode()
+  }
+}
+
+private[mllib] object GridPartitioner {
+
+  /** Creates a new [[GridPartitioner]] instance. */
+  def apply(rows: Int, cols: Int, rowsPerPart: Int, colsPerPart: Int): GridPartitioner = {
+    new GridPartitioner(rows, cols, rowsPerPart, colsPerPart)
+  }
+
+  /** Creates a new [[GridPartitioner]] instance with the input suggested number of partitions.
*/ + def apply(rows: Int, cols: Int, suggestedNumPartitions: Int): GridPartitioner = { + require(suggestedNumPartitions > 0) + val scale = 1.0 / math.sqrt(suggestedNumPartitions) + val rowsPerPart = math.round(math.max(scale * rows, 1.0)).toInt + val colsPerPart = math.round(math.max(scale * cols, 1.0)).toInt + new GridPartitioner(rows, cols, rowsPerPart, colsPerPart) + } +} + +/** + * Represents a distributed matrix in blocks of local matrices. + * + * @param blocks The RDD of sub-matrix blocks (blockRowIndex, blockColIndex, sub-matrix) that form + * this distributed matrix. + * @param rowsPerBlock Number of rows that make up each block. The blocks forming the final + * rows are not required to have the given number of rows + * @param colsPerBlock Number of columns that make up each block. The blocks forming the final + * columns are not required to have the given number of columns + * @param nRows Number of rows of this matrix. If the supplied value is less than or equal to zero, + * the number of rows will be calculated when `numRows` is invoked. + * @param nCols Number of columns of this matrix. If the supplied value is less than or equal to + * zero, the number of columns will be calculated when `numCols` is invoked. + */ +class BlockMatrix( + val blocks: RDD[((Int, Int), Matrix)], + val rowsPerBlock: Int, + val colsPerBlock: Int, + private var nRows: Long, + private var nCols: Long) extends DistributedMatrix with Logging { + + private type MatrixBlock = ((Int, Int), Matrix) // ((blockRowIndex, blockColIndex), sub-matrix) + + /** + * Alternate constructor for BlockMatrix without the input of the number of rows and columns. + * + * @param rdd The RDD of SubMatrices (local matrices) that form this matrix + * @param rowsPerBlock Number of rows that make up each block. The blocks forming the final + * rows are not required to have the given number of rows + * @param colsPerBlock Number of columns that make up each block. 
The blocks forming the final + * columns are not required to have the given number of columns + */ + def this( + rdd: RDD[((Int, Int), Matrix)], + rowsPerBlock: Int, + colsPerBlock: Int) = { + this(rdd, rowsPerBlock, colsPerBlock, 0L, 0L) + } + + override def numRows(): Long = { + if (nRows <= 0L) estimateDim() + nRows + } + + override def numCols(): Long = { + if (nCols <= 0L) estimateDim() + nCols + } + + val numRowBlocks = math.ceil(numRows() * 1.0 / rowsPerBlock).toInt + val numColBlocks = math.ceil(numCols() * 1.0 / colsPerBlock).toInt + + private[mllib] var partitioner: GridPartitioner = + GridPartitioner(numRowBlocks, numColBlocks, suggestedNumPartitions = blocks.partitions.size) + + /** Estimates the dimensions of the matrix. */ + private def estimateDim(): Unit = { + val (rows, cols) = blocks.map { case ((blockRowIndex, blockColIndex), mat) => + (blockRowIndex.toLong * rowsPerBlock + mat.numRows, + blockColIndex.toLong * colsPerBlock + mat.numCols) + }.reduce { (x0, x1) => + (math.max(x0._1, x1._1), math.max(x0._2, x1._2)) + } + if (nRows <= 0L) nRows = rows + assert(rows <= nRows, s"The number of rows $rows is more than claimed $nRows.") + if (nCols <= 0L) nCols = cols + assert(cols <= nCols, s"The number of columns $cols is more than claimed $nCols.") + } + + /** Caches the underlying RDD. */ + def cache(): this.type = { + blocks.cache() + this + } + + /** Persists the underlying RDD with the specified storage level. */ + def persist(storageLevel: StorageLevel): this.type = { + blocks.persist(storageLevel) + this + } + + /** Converts to CoordinateMatrix. 
*/ + def toCoordinateMatrix(): CoordinateMatrix = { + val entryRDD = blocks.flatMap { case ((blockRowIndex, blockColIndex), mat) => + val rowStart = blockRowIndex.toLong * rowsPerBlock + val colStart = blockColIndex.toLong * colsPerBlock + val entryValues = new ArrayBuffer[MatrixEntry]() + mat.foreachActive { (i, j, v) => + if (v != 0.0) entryValues.append(new MatrixEntry(rowStart + i, colStart + j, v)) + } + entryValues + } + new CoordinateMatrix(entryRDD, numRows(), numCols()) + } + + /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ + def toIndexedRowMatrix(): IndexedRowMatrix = { + require(numCols() < Int.MaxValue, "The number of columns must be within the integer range. " + + s"numCols: ${numCols()}") + // TODO: This implementation may be optimized + toCoordinateMatrix().toIndexedRowMatrix() + } + + /** Collect the distributed matrix on the driver as a `DenseMatrix`. */ + def toLocalMatrix(): Matrix = { + require(numRows() < Int.MaxValue, "The number of rows of this matrix should be less than " + + s"Int.MaxValue. Currently numRows: ${numRows()}") + require(numCols() < Int.MaxValue, "The number of columns of this matrix should be less than " + + s"Int.MaxValue. Currently numCols: ${numCols()}") + require(numRows() * numCols() < Int.MaxValue, "The length of the values array must be " + + s"less than Int.MaxValue. 
Currently numRows * numCols: ${numRows() * numCols()}") + val m = numRows().toInt + val n = numCols().toInt + val mem = m * n / 125000 + if (mem > 500) logWarning(s"Storing this matrix will require $mem MB of memory!") + + val localBlocks = blocks.collect() + val values = new Array[Double](m * n) + localBlocks.foreach { case ((blockRowIndex, blockColIndex), submat) => + val rowOffset = blockRowIndex * rowsPerBlock + val colOffset = blockColIndex * colsPerBlock + submat.foreachActive { (i, j, v) => + val indexOffset = (j + colOffset) * m + rowOffset + i + values(indexOffset) = v + } + } + new DenseMatrix(m, n, values) + } + + /** Transpose this `BlockMatrix`. Returns a new `BlockMatrix` instance sharing the + * same underlying data. Is a lazy operation. */ + def transpose: BlockMatrix = { + val transposedBlocks = blocks.map { case ((blockRowIndex, blockColIndex), mat) => + ((blockColIndex, blockRowIndex), mat.transpose) + } + new BlockMatrix(transposedBlocks, colsPerBlock, rowsPerBlock, nCols, nRows) + } + + /** Collects data and assembles a local dense breeze matrix (for test only). 
*/ + private[mllib] def toBreeze(): BDM[Double] = { + val localMat = toLocalMatrix() + new BDM[Double](localMat.numRows, localMat.numCols, localMat.toArray) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index b60559c853a50..078d1fac44443 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -21,8 +21,7 @@ import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} /** * :: Experimental :: @@ -98,6 +97,46 @@ class CoordinateMatrix( toIndexedRowMatrix().toRowMatrix() } + /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */ + def toBlockMatrix(): BlockMatrix = { + toBlockMatrix(1024, 1024) + } + + /** + * Converts to BlockMatrix. Creates blocks of [[SparseMatrix]]. + * @param rowsPerBlock The number of rows of each block. The blocks at the bottom edge may have + * a smaller value. Must be an integer value greater than 0. + * @param colsPerBlock The number of columns of each block. The blocks at the right edge may have + * a smaller value. Must be an integer value greater than 0. + * @return a [[BlockMatrix]] + */ + def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { + require(rowsPerBlock > 0, + s"rowsPerBlock needs to be greater than 0. rowsPerBlock: $rowsPerBlock") + require(colsPerBlock > 0, + s"colsPerBlock needs to be greater than 0. 
colsPerBlock: $colsPerBlock") + val m = numRows() + val n = numCols() + val numRowBlocks = math.ceil(m.toDouble / rowsPerBlock).toInt + val numColBlocks = math.ceil(n.toDouble / colsPerBlock).toInt + val partitioner = GridPartitioner(numRowBlocks, numColBlocks, entries.partitions.length) + + val blocks: RDD[((Int, Int), Matrix)] = entries.map { entry => + val blockRowIndex = (entry.i / rowsPerBlock).toInt + val blockColIndex = (entry.j / colsPerBlock).toInt + + val rowId = entry.i % rowsPerBlock + val colId = entry.j % colsPerBlock + + ((blockRowIndex, blockColIndex), (rowId.toInt, colId.toInt, entry.value)) + }.groupByKey(partitioner).map { case ((blockRowIndex, blockColIndex), entry) => + val effRows = math.min(m - blockRowIndex.toLong * rowsPerBlock, rowsPerBlock).toInt + val effCols = math.min(n - blockColIndex.toLong * colsPerBlock, colsPerBlock).toInt + ((blockRowIndex, blockColIndex), SparseMatrix.fromCOO(effRows, effCols, entry)) + } + new BlockMatrix(blocks, rowsPerBlock, colsPerBlock, m, n) + } + /** Determines the size by computing the max row/column index. */ private def computeSize() { // Reduce will throw an exception if `entries` is empty. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index c518271f04729..3be530fa07537 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -75,6 +75,24 @@ class IndexedRowMatrix( new RowMatrix(rows.map(_.vector), 0L, nCols) } + /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */ + def toBlockMatrix(): BlockMatrix = { + toBlockMatrix(1024, 1024) + } + + /** + * Converts to BlockMatrix. Creates blocks of [[SparseMatrix]]. + * @param rowsPerBlock The number of rows of each block. 
The blocks at the bottom edge may have + * a smaller value. Must be an integer value greater than 0. + * @param colsPerBlock The number of columns of each block. The blocks at the right edge may have + * a smaller value. Must be an integer value greater than 0. + * @return a [[BlockMatrix]] + */ + def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { + // TODO: This implementation may be optimized + toCoordinateMatrix().toBlockMatrix(rowsPerBlock, colsPerBlock) + } + /** * Converts this matrix to a * [[org.apache.spark.mllib.linalg.distributed.CoordinateMatrix]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 02075edbabf85..ddca30c3c01c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -30,7 +30,6 @@ import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg._ -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 0857877951c82..4b7d0589c973b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -25,7 +25,6 @@ import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Vectors, Vector} -import 
org.apache.spark.mllib.rdd.RDDFunctions._ /** * Class used to solve an optimization problem using Gradient Descent. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index d16d0daf08565..d5e4f4ccbff10 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -26,7 +26,6 @@ import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.axpy -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 57c0768084e41..78172843be56e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -21,10 +21,7 @@ import scala.language.implicitConversions import scala.reflect.ClassTag import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.HashPartitioner -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils /** * Machine learning specific RDD functions. @@ -53,63 +50,25 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { * Reduces the elements of this RDD in a multi-level tree pattern. * * @param depth suggested depth of the tree (default: 2) - * @see [[org.apache.spark.rdd.RDD#reduce]] + * @see [[org.apache.spark.rdd.RDD#treeReduce]] + * @deprecated Use [[org.apache.spark.rdd.RDD#treeReduce]] instead. 
*/ - def treeReduce(f: (T, T) => T, depth: Int = 2): T = { - require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") - val cleanF = self.context.clean(f) - val reducePartition: Iterator[T] => Option[T] = iter => { - if (iter.hasNext) { - Some(iter.reduceLeft(cleanF)) - } else { - None - } - } - val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it))) - val op: (Option[T], Option[T]) => Option[T] = (c, x) => { - if (c.isDefined && x.isDefined) { - Some(cleanF(c.get, x.get)) - } else if (c.isDefined) { - c - } else if (x.isDefined) { - x - } else { - None - } - } - RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth) - .getOrElse(throw new UnsupportedOperationException("empty collection")) - } + @deprecated("Use RDD.treeReduce instead.", "1.3.0") + def treeReduce(f: (T, T) => T, depth: Int = 2): T = self.treeReduce(f, depth) /** * Aggregates the elements of this RDD in a multi-level tree pattern. * * @param depth suggested depth of the tree (default: 2) - * @see [[org.apache.spark.rdd.RDD#aggregate]] + * @see [[org.apache.spark.rdd.RDD#treeAggregate]] + * @deprecated Use [[org.apache.spark.rdd.RDD#treeAggregate]] instead. 
*/ + @deprecated("Use RDD.treeAggregate instead.", "1.3.0") def treeAggregate[U: ClassTag](zeroValue: U)( seqOp: (U, T) => U, combOp: (U, U) => U, depth: Int = 2): U = { - require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") - if (self.partitions.size == 0) { - return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance()) - } - val cleanSeqOp = self.context.clean(seqOp) - val cleanCombOp = self.context.clean(combOp) - val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) - var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it))) - var numPartitions = partiallyAggregated.partitions.size - val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) - // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. - while (numPartitions > scale + numPartitions / scale) { - numPartitions /= scale - val curNumPartitions = numPartitions - partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => - iter.map((i % curNumPartitions, _)) - }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values - } - partiallyAggregated.reduce(cleanCombOp) + self.treeAggregate(zeroValue)(seqOp, combOp, depth) } } @@ -117,5 +76,5 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { object RDDFunctions { /** Implicit conversion from an RDD to RDDFunctions. 
*/ - implicit def fromRDD[T: ClassTag](rdd: RDD[T]) = new RDDFunctions[T](rdd) + implicit def fromRDD[T: ClassTag](rdd: RDD[T]): RDDFunctions[T] = new RDDFunctions[T](rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index e9304b5e5c650..482dd4b272d1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -140,6 +140,7 @@ private class RandomForest ( logDebug("maxBins = " + metadata.maxBins) logDebug("featureSubsetStrategy = " + featureSubsetStrategy) logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode) + logDebug("subsamplingRate = " + strategy.subsamplingRate) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. @@ -155,19 +156,12 @@ private class RandomForest ( // Cache input RDD for speedup during multiple passes. 
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) - val (subsample, withReplacement) = { - // TODO: Have a stricter check for RF in the strategy - val isRandomForest = numTrees > 1 - if (isRandomForest) { - (1.0, true) - } else { - (strategy.subsamplingRate, false) - } - } + val withReplacement = if (numTrees > 1) true else false val baggedInput - = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed) - .persist(StorageLevel.MEMORY_AND_DISK) + = BaggedPoint.convertToBaggedRDD(treeInput, + strategy.subsamplingRate, numTrees, + withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK) // depth of the decision tree val maxDepth = strategy.maxDepth diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index 0ef9c6181a0a0..b6099259971b7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -29,8 +29,8 @@ object Algo extends Enumeration { val Classification, Regression = Value private[mllib] def fromString(name: String): Algo = name match { - case "classification" => Classification - case "regression" => Regression + case "classification" | "Classification" => Classification + case "regression" | "Regression" => Regression case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name") } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 972959885f396..3308adb6752ff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -156,6 +156,9 @@ class Strategy ( s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given 
$minInstancesPerNode") require(maxMemoryInMB <= 10240, s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") + require(subsamplingRate > 0 && subsamplingRate <= 1, + s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " + + s"$subsamplingRate") } /** Returns a shallow copy of this instance. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 0e02345aa3774..b7950e00786ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -94,6 +94,10 @@ private[tree] class EntropyAggregator(numClasses: Int) throw new IllegalArgumentException(s"EntropyAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } + if (label < 0) { + throw new IllegalArgumentException(s"EntropyAggregator given label $label" + + s"but requires label is non-negative.") + } allStats(offset + label.toInt) += instanceWeight } @@ -147,6 +151,7 @@ private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalc val lbl = label.toInt require(lbl < stats.length, s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}") + require(lbl >= 0, "Entropy does not support negative labels") val cnt = count if (cnt == 0) { 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 7c83cd48e16a0..c946db9c0d1c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -90,6 +90,10 @@ private[tree] class GiniAggregator(numClasses: Int) throw new IllegalArgumentException(s"GiniAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } + if (label < 
0) { + throw new IllegalArgumentException(s"GiniAggregator given label $label" + + s"but requires label is non-negative.") + } allStats(offset + label.toInt) += instanceWeight } @@ -143,6 +147,7 @@ private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcula val lbl = label.toInt require(lbl < stats.length, s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}") + require(lbl >= 0, "GiniImpurity does not support negative labels") val cnt = count if (cnt == 0) { 0 diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 47f1f46c6c260..56a9dbdd58b64 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -37,7 +37,7 @@ public class JavaPipelineSuite { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { @@ -65,7 +65,7 @@ public void pipeline() { .setStages(new PipelineStage[] {scaler, lr}); PipelineModel model = pipeline.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); predictions.collectAsList(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 2eba83335bb58..f4ba23c44563e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -34,7 +34,7 @@ public class JavaLogisticRegressionSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { @@ -55,7 +55,7 @@ public void logisticRegression() { LogisticRegression lr = new LogisticRegression(); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); predictions.collectAsList(); } @@ -67,7 +67,7 @@ public void logisticRegressionWithSetters() { LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold .registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); predictions.collectAsList(); } diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index a9f1c4a2c3ca7..074b58c07df7a 100644 --- 
a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -30,7 +30,7 @@ import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -38,7 +38,7 @@ public class JavaCrossValidatorSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 4515084bc7ae9..2f175fb117941 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -23,7 +23,7 @@ import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame class PipelineSuite extends FunSuite { @@ -36,11 +36,11 @@ class PipelineSuite extends FunSuite { val estimator2 = mock[Estimator[MyModel]] val model2 = mock[MyModel] val transformer3 = mock[Transformer] - val dataset0 = mock[SchemaRDD] - val dataset1 = mock[SchemaRDD] - val dataset2 = mock[SchemaRDD] - val dataset3 = mock[SchemaRDD] - val dataset4 = mock[SchemaRDD] + val dataset0 = mock[DataFrame] + val dataset1 = mock[DataFrame] + val dataset2 = mock[DataFrame] + val dataset3 = mock[DataFrame] + val dataset4 = mock[DataFrame] when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0) 
when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1) @@ -74,7 +74,7 @@ class PipelineSuite extends FunSuite { val estimator = mock[Estimator[MyModel]] val pipeline = new Pipeline() .setStages(Array(estimator, estimator)) - val dataset = mock[SchemaRDD] + val dataset = mock[DataFrame] intercept[IllegalArgumentException] { pipeline.fit(dataset) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index e8030fef55b1d..33e40dc7410cc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -21,49 +21,43 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{SQLContext, DataFrame} class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { @transient var sqlContext: SQLContext = _ - @transient var dataset: SchemaRDD = _ + @transient var dataset: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() sqlContext = new SQLContext(sc) - dataset = sqlContext.createSchemaRDD( + dataset = sqlContext.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) } test("logistic regression") { - val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression val model = lr.fit(dataset) model.transform(dataset) - .select('label, 'prediction) + .select("label", "prediction") .collect() } test("logistic regression with setters") { - val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression() .setMaxIter(10) .setRegParam(1.0) val model = lr.fit(dataset) model.transform(dataset, model.threshold 
-> 0.8) // overwrite threshold - .select('label, 'score, 'prediction) + .select("label", "score", "prediction") .collect() } test("logistic regression fit and transform with varargs") { - val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0) model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability") - .select('label, 'probability, 'prediction) + .select("label", "probability", "prediction") .collect() } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index cdd4db1b5b7dc..9da253c61d36f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -350,7 +350,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { numItemBlocks: Int = 3, targetRMSE: Double = 0.05): Unit = { val sqlContext = this.sqlContext - import sqlContext.{createSchemaRDD, symbolToUnresolvedAttribute} + import sqlContext.createDataFrame val als = new ALS() .setRank(rank) .setRegParam(regParam) @@ -360,7 +360,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { val alpha = als.getAlpha val model = als.fit(training) val predictions = model.transform(test) - .select('rating, 'prediction) + .select("rating", "prediction") .map { case Row(rating: Float, prediction: Float) => (rating.toDouble, prediction.toDouble) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 41cc13da4d5b1..761ea821ef7c6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -23,16 +23,16 @@ import 
org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{SQLContext, DataFrame} class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { - @transient var dataset: SchemaRDD = _ + @transient var dataset: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() val sqlContext = new SQLContext(sc) - dataset = sqlContext.createSchemaRDD( + dataset = sqlContext.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 4c93c0ca4f86c..e9e510b6f5546 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -22,7 +22,6 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} import org.apache.spark.rdd.RDD diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 771878e925ea7..b0b78acd6df16 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -169,16 +169,17 @@ class BLASSuite extends FunSuite { } test("gemm") { - val dA = new DenseMatrix(4, 
3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) val B = new DenseMatrix(3, 2, Array(1.0, 0.0, 0.0, 0.0, 2.0, 1.0)) val expected = new DenseMatrix(4, 2, Array(0.0, 1.0, 0.0, 0.0, 4.0, 0.0, 2.0, 3.0)) + val BTman = new DenseMatrix(2, 3, Array(1.0, 0.0, 0.0, 2.0, 0.0, 1.0)) + val BT = B.transpose - assert(dA multiply B ~== expected absTol 1e-15) - assert(sA multiply B ~== expected absTol 1e-15) + assert(dA.multiply(B) ~== expected absTol 1e-15) + assert(sA.multiply(B) ~== expected absTol 1e-15) val C1 = new DenseMatrix(4, 2, Array(1.0, 0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0)) val C2 = C1.copy @@ -188,6 +189,10 @@ class BLASSuite extends FunSuite { val C6 = C1.copy val C7 = C1.copy val C8 = C1.copy + val C9 = C1.copy + val C10 = C1.copy + val C11 = C1.copy + val C12 = C1.copy val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0)) val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0)) @@ -202,26 +207,40 @@ class BLASSuite extends FunSuite { withClue("columns of A don't match the rows of B") { intercept[Exception] { - gemm(true, false, 1.0, dA, B, 2.0, C1) + gemm(1.0, dA.transpose, B, 2.0, C1) } } - val dAT = + val dATman = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) - val sAT = + val sATman = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - assert(dAT transposeMultiply B ~== expected absTol 1e-15) - assert(sAT transposeMultiply B ~== expected absTol 1e-15) - - gemm(true, false, 1.0, dAT, B, 2.0, C5) - gemm(true, false, 1.0, sAT, B, 2.0, C6) - gemm(true, false, 2.0, dAT, B, 2.0, C7) - gemm(true, false, 2.0, sAT, B, 2.0, C8) + val dATT = dATman.transpose + val sATT = sATman.transpose + val BTT = BTman.transpose.asInstanceOf[DenseMatrix] + + assert(dATT.multiply(B) ~== expected absTol 1e-15) + 
assert(sATT.multiply(B) ~== expected absTol 1e-15) + assert(dATT.multiply(BTT) ~== expected absTol 1e-15) + assert(sATT.multiply(BTT) ~== expected absTol 1e-15) + + gemm(1.0, dATT, BTT, 2.0, C5) + gemm(1.0, sATT, BTT, 2.0, C6) + gemm(2.0, dATT, BTT, 2.0, C7) + gemm(2.0, sATT, BTT, 2.0, C8) + gemm(1.0, dA, BTT, 2.0, C9) + gemm(1.0, sA, BTT, 2.0, C10) + gemm(2.0, dA, BTT, 2.0, C11) + gemm(2.0, sA, BTT, 2.0, C12) assert(C5 ~== expected2 absTol 1e-15) assert(C6 ~== expected2 absTol 1e-15) assert(C7 ~== expected3 absTol 1e-15) assert(C8 ~== expected3 absTol 1e-15) + assert(C9 ~== expected2 absTol 1e-15) + assert(C10 ~== expected2 absTol 1e-15) + assert(C11 ~== expected3 absTol 1e-15) + assert(C12 ~== expected3 absTol 1e-15) } test("gemv") { @@ -233,17 +252,13 @@ class BLASSuite extends FunSuite { val x = new DenseVector(Array(1.0, 2.0, 3.0)) val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0)) - assert(dA multiply x ~== expected absTol 1e-15) - assert(sA multiply x ~== expected absTol 1e-15) + assert(dA.multiply(x) ~== expected absTol 1e-15) + assert(sA.multiply(x) ~== expected absTol 1e-15) val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) val y2 = y1.copy val y3 = y1.copy val y4 = y1.copy - val y5 = y1.copy - val y6 = y1.copy - val y7 = y1.copy - val y8 = y1.copy val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0)) val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0)) @@ -257,25 +272,18 @@ class BLASSuite extends FunSuite { assert(y4 ~== expected3 absTol 1e-15) withClue("columns of A don't match the rows of B") { intercept[Exception] { - gemv(true, 1.0, dA, x, 2.0, y1) + gemv(1.0, dA.transpose, x, 2.0, y1) } } - val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - assert(dAT transposeMultiply x ~== expected absTol 1e-15) - assert(sAT transposeMultiply x ~== expected absTol 1e-15) - - gemv(true, 1.0, 
dAT, x, 2.0, y5) - gemv(true, 1.0, sAT, x, 2.0, y6) - gemv(true, 2.0, dAT, x, 2.0, y7) - gemv(true, 2.0, sAT, x, 2.0, y8) - assert(y5 ~== expected2 absTol 1e-15) - assert(y6 ~== expected2 absTol 1e-15) - assert(y7 ~== expected3 absTol 1e-15) - assert(y8 ~== expected3 absTol 1e-15) + val dATT = dAT.transpose + val sATT = sAT.transpose + + assert(dATT.multiply(x) ~== expected absTol 1e-15) + assert(sATT.multiply(x) ~== expected absTol 1e-15) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index 73a6d3a27d868..2031032373971 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -36,6 +36,11 @@ class BreezeMatrixConversionSuite extends FunSuite { assert(mat.numRows === breeze.rows) assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") + // transposed matrix + val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[DenseMatrix] + assert(matTransposed.numRows === breeze.cols) + assert(matTransposed.numCols === breeze.rows) + assert(matTransposed.values.eq(breeze.data), "should not copy data") } test("sparse matrix to breeze") { @@ -58,5 +63,9 @@ class BreezeMatrixConversionSuite extends FunSuite { assert(mat.numRows === breeze.rows) assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") + val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[SparseMatrix] + assert(matTransposed.numRows === breeze.cols) + assert(matTransposed.numCols === breeze.rows) + assert(!matTransposed.values.eq(breeze.data), "has to copy data") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 
a35d0fe389fdd..b1ebfde0e5e57 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -22,6 +22,9 @@ import java.util.Random import org.mockito.Mockito.when import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar._ +import scala.collection.mutable.{Map => MutableMap} + +import org.apache.spark.mllib.util.TestingUtils._ class MatricesSuite extends FunSuite { test("dense matrix construction") { @@ -32,7 +35,6 @@ class MatricesSuite extends FunSuite { assert(mat.numRows === m) assert(mat.numCols === n) assert(mat.values.eq(values), "should not copy data") - assert(mat.toArray.eq(values), "toArray should not copy data") } test("dense matrix construction with wrong dimension") { @@ -161,6 +163,66 @@ class MatricesSuite extends FunSuite { assert(deMat1.toArray === deMat2.toArray) } + test("transpose") { + val dA = + new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) + val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) + + val dAT = dA.transpose.asInstanceOf[DenseMatrix] + val sAT = sA.transpose.asInstanceOf[SparseMatrix] + val dATexpected = + new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) + val sATexpected = + new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) + + assert(dAT.toBreeze === dATexpected.toBreeze) + assert(sAT.toBreeze === sATexpected.toBreeze) + assert(dA(1, 0) === dAT(0, 1)) + assert(dA(2, 1) === dAT(1, 2)) + assert(sA(1, 0) === sAT(0, 1)) + assert(sA(2, 1) === sAT(1, 2)) + + assert(!dA.toArray.eq(dAT.toArray), "has to have a new array") + assert(dA.values.eq(dAT.transpose.asInstanceOf[DenseMatrix].values), "should not copy array") + + assert(dAT.toSparse().toBreeze === sATexpected.toBreeze) + assert(sAT.toDense().toBreeze === dATexpected.toBreeze) + } + + 
test("foreachActive") { + val m = 3 + val n = 2 + val values = Array(1.0, 2.0, 4.0, 5.0) + val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(0, 1, 1, 2) + + val sp = new SparseMatrix(m, n, colPtrs, rowIndices, values) + val dn = new DenseMatrix(m, n, allValues) + + val dnMap = MutableMap[(Int, Int), Double]() + dn.foreachActive { (i, j, value) => + dnMap.put((i, j), value) + } + assert(dnMap.size === 6) + assert(dnMap(0, 0) === 1.0) + assert(dnMap(1, 0) === 2.0) + assert(dnMap(2, 0) === 0.0) + assert(dnMap(0, 1) === 0.0) + assert(dnMap(1, 1) === 4.0) + assert(dnMap(2, 1) === 5.0) + + val spMap = MutableMap[(Int, Int), Double]() + sp.foreachActive { (i, j, value) => + spMap.put((i, j), value) + } + assert(spMap.size === 4) + assert(spMap(0, 0) === 1.0) + assert(spMap(1, 0) === 2.0) + assert(spMap(1, 1) === 4.0) + assert(spMap(2, 1) === 5.0) + } + test("horzcat, vertcat, eye, speye") { val m = 3 val n = 2 @@ -168,9 +230,20 @@ class MatricesSuite extends FunSuite { val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) val colPtrs = Array(0, 2, 4) val rowIndices = Array(0, 1, 1, 2) + // transposed versions + val allValuesT = Array(1.0, 0.0, 2.0, 4.0, 0.0, 5.0) + val colPtrsT = Array(0, 1, 3, 4) + val rowIndicesT = Array(0, 0, 1, 1) val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) val deMat1 = new DenseMatrix(m, n, allValues) + val spMat1T = new SparseMatrix(n, m, colPtrsT, rowIndicesT, values) + val deMat1T = new DenseMatrix(n, m, allValuesT) + + // should equal spMat1 & deMat1 respectively + val spMat1TT = spMat1T.transpose + val deMat1TT = deMat1T.transpose + val deMat2 = Matrices.eye(3) val spMat2 = Matrices.speye(3) val deMat3 = Matrices.eye(2) @@ -180,7 +253,6 @@ class MatricesSuite extends FunSuite { val spHorz2 = Matrices.horzcat(Array(spMat1, deMat2)) val spHorz3 = Matrices.horzcat(Array(deMat1, spMat2)) val deHorz1 = Matrices.horzcat(Array(deMat1, deMat2)) - val deHorz2 = 
Matrices.horzcat(Array[Matrix]()) assert(deHorz1.numRows === 3) @@ -195,8 +267,8 @@ class MatricesSuite extends FunSuite { assert(deHorz2.numCols === 0) assert(deHorz2.toArray.length === 0) - assert(deHorz1.toBreeze.toDenseMatrix === spHorz2.toBreeze.toDenseMatrix) - assert(spHorz2.toBreeze === spHorz3.toBreeze) + assert(deHorz1 ~== spHorz2.asInstanceOf[SparseMatrix].toDense absTol 1e-15) + assert(spHorz2 ~== spHorz3 absTol 1e-15) assert(spHorz(0, 0) === 1.0) assert(spHorz(2, 1) === 5.0) assert(spHorz(0, 2) === 1.0) @@ -212,6 +284,17 @@ class MatricesSuite extends FunSuite { assert(deHorz1(2, 4) === 1.0) assert(deHorz1(1, 4) === 0.0) + // containing transposed matrices + val spHorzT = Matrices.horzcat(Array(spMat1TT, spMat2)) + val spHorz2T = Matrices.horzcat(Array(spMat1TT, deMat2)) + val spHorz3T = Matrices.horzcat(Array(deMat1TT, spMat2)) + val deHorz1T = Matrices.horzcat(Array(deMat1TT, deMat2)) + + assert(deHorz1T ~== deHorz1 absTol 1e-15) + assert(spHorzT ~== spHorz absTol 1e-15) + assert(spHorz2T ~== spHorz2 absTol 1e-15) + assert(spHorz3T ~== spHorz3 absTol 1e-15) + intercept[IllegalArgumentException] { Matrices.horzcat(Array(spMat1, spMat3)) } @@ -238,8 +321,8 @@ class MatricesSuite extends FunSuite { assert(deVert2.numCols === 0) assert(deVert2.toArray.length === 0) - assert(deVert1.toBreeze.toDenseMatrix === spVert2.toBreeze.toDenseMatrix) - assert(spVert2.toBreeze === spVert3.toBreeze) + assert(deVert1 ~== spVert2.asInstanceOf[SparseMatrix].toDense absTol 1e-15) + assert(spVert2 ~== spVert3 absTol 1e-15) assert(spVert(0, 0) === 1.0) assert(spVert(2, 1) === 5.0) assert(spVert(3, 0) === 1.0) @@ -251,6 +334,17 @@ class MatricesSuite extends FunSuite { assert(deVert1(3, 1) === 0.0) assert(deVert1(4, 1) === 1.0) + // containing transposed matrices + val spVertT = Matrices.vertcat(Array(spMat1TT, spMat3)) + val deVert1T = Matrices.vertcat(Array(deMat1TT, deMat3)) + val spVert2T = Matrices.vertcat(Array(spMat1TT, deMat3)) + val spVert3T = 
Matrices.vertcat(Array(deMat1TT, spMat3)) + + assert(deVert1T ~== deVert1 absTol 1e-15) + assert(spVertT ~== spVert absTol 1e-15) + assert(spVert2T ~== spVert2 absTol 1e-15) + assert(spVert3T ~== spVert3 absTol 1e-15) + intercept[IllegalArgumentException] { Matrices.vertcat(Array(spMat1, spMat2)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala new file mode 100644 index 0000000000000..03f34308dd09b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.linalg.distributed + +import scala.util.Random + +import breeze.linalg.{DenseMatrix => BDM} +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix} +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext { + + val m = 5 + val n = 4 + val rowPerPart = 2 + val colPerPart = 2 + val numPartitions = 3 + var gridBasedMat: BlockMatrix = _ + + override def beforeAll() { + super.beforeAll() + + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 0), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + + gridBasedMat = new BlockMatrix(sc.parallelize(blocks, numPartitions), rowPerPart, colPerPart) + } + + test("size") { + assert(gridBasedMat.numRows() === m) + assert(gridBasedMat.numCols() === n) + } + + test("grid partitioner") { + val random = new Random() + // This should generate a 4x4 grid of 1x2 blocks. 
+ val part0 = GridPartitioner(4, 7, suggestedNumPartitions = 12) + val expected0 = Array( + Array(0, 0, 4, 4, 8, 8, 12), + Array(1, 1, 5, 5, 9, 9, 13), + Array(2, 2, 6, 6, 10, 10, 14), + Array(3, 3, 7, 7, 11, 11, 15)) + for (i <- 0 until 4; j <- 0 until 7) { + assert(part0.getPartition((i, j)) === expected0(i)(j)) + assert(part0.getPartition((i, j, random.nextInt())) === expected0(i)(j)) + } + + intercept[IllegalArgumentException] { + part0.getPartition((-1, 0)) + } + + intercept[IllegalArgumentException] { + part0.getPartition((4, 0)) + } + + intercept[IllegalArgumentException] { + part0.getPartition((0, -1)) + } + + intercept[IllegalArgumentException] { + part0.getPartition((0, 7)) + } + + val part1 = GridPartitioner(2, 2, suggestedNumPartitions = 5) + val expected1 = Array( + Array(0, 2), + Array(1, 3)) + for (i <- 0 until 2; j <- 0 until 2) { + assert(part1.getPartition((i, j)) === expected1(i)(j)) + assert(part1.getPartition((i, j, random.nextInt())) === expected1(i)(j)) + } + + val part2 = GridPartitioner(2, 2, suggestedNumPartitions = 5) + assert(part0 !== part2) + assert(part1 === part2) + + val part3 = new GridPartitioner(2, 3, rowsPerPart = 1, colsPerPart = 2) + val expected3 = Array( + Array(0, 0, 2), + Array(1, 1, 3)) + for (i <- 0 until 2; j <- 0 until 3) { + assert(part3.getPartition((i, j)) === expected3(i)(j)) + assert(part3.getPartition((i, j, random.nextInt())) === expected3(i)(j)) + } + + val part4 = GridPartitioner(2, 3, rowsPerPart = 1, colsPerPart = 2) + assert(part3 === part4) + + intercept[IllegalArgumentException] { + new GridPartitioner(2, 2, rowsPerPart = 0, colsPerPart = 1) + } + + intercept[IllegalArgumentException] { + GridPartitioner(2, 2, rowsPerPart = 1, colsPerPart = 0) + } + + intercept[IllegalArgumentException] { + GridPartitioner(2, 2, suggestedNumPartitions = 0) + } + } + + test("toCoordinateMatrix") { + val coordMat = gridBasedMat.toCoordinateMatrix() + assert(coordMat.numRows() === m) + assert(coordMat.numCols() === n) + 
assert(coordMat.toBreeze() === gridBasedMat.toBreeze()) + } + + test("toIndexedRowMatrix") { + val rowMat = gridBasedMat.toIndexedRowMatrix() + assert(rowMat.numRows() === m) + assert(rowMat.numCols() === n) + assert(rowMat.toBreeze() === gridBasedMat.toBreeze()) + } + + test("toBreeze and toLocalMatrix") { + val expected = BDM( + (1.0, 0.0, 0.0, 0.0), + (0.0, 2.0, 1.0, 0.0), + (3.0, 1.0, 1.0, 0.0), + (0.0, 1.0, 2.0, 1.0), + (0.0, 0.0, 1.0, 5.0)) + + val dense = Matrices.fromBreeze(expected).asInstanceOf[DenseMatrix] + assert(gridBasedMat.toLocalMatrix() === dense) + assert(gridBasedMat.toBreeze() === expected) + } + + test("transpose") { + val expected = BDM( + (1.0, 0.0, 3.0, 0.0, 0.0), + (0.0, 2.0, 1.0, 1.0, 0.0), + (0.0, 1.0, 1.0, 2.0, 1.0), + (0.0, 0.0, 0.0, 1.0, 5.0)) + + val AT = gridBasedMat.transpose + assert(AT.numRows() === gridBasedMat.numCols()) + assert(AT.numCols() === gridBasedMat.numRows()) + assert(AT.toBreeze() === expected) + + // partitioner must update as well + val originalPartitioner = gridBasedMat.partitioner + val ATpartitioner = AT.partitioner + assert(originalPartitioner.colsPerPart === ATpartitioner.rowsPerPart) + assert(originalPartitioner.rowsPerPart === ATpartitioner.colsPerPart) + assert(originalPartitioner.cols === ATpartitioner.rows) + assert(originalPartitioner.rows === ATpartitioner.cols) + + // make sure it works when matrices are cached as well + gridBasedMat.cache() + val AT2 = gridBasedMat.transpose + AT2.cache() + assert(AT2.toBreeze() === AT.toBreeze()) + val A = AT2.transpose + assert(A.toBreeze() === gridBasedMat.toBreeze()) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index 80bef814ce50d..04b36a9ef9990 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -100,4 +100,18 @@ class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { Vectors.dense(0.0, 9.0, 0.0, 0.0)) assert(rows === expected) } + + test("toBlockMatrix") { + val blockMat = mat.toBlockMatrix(2, 2) + assert(blockMat.numRows() === m) + assert(blockMat.numCols() === n) + assert(blockMat.toBreeze() === mat.toBreeze()) + + intercept[IllegalArgumentException] { + mat.toBlockMatrix(-1, 2) + } + intercept[IllegalArgumentException] { + mat.toBlockMatrix(2, 0) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index b86c2ca5ff136..2ab53cc13db71 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -88,6 +88,21 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(coordMat.toBreeze() === idxRowMat.toBreeze()) } + test("toBlockMatrix") { + val idxRowMat = new IndexedRowMatrix(indexedRows) + val blockMat = idxRowMat.toBlockMatrix(2, 2) + assert(blockMat.numRows() === m) + assert(blockMat.numCols() === n) + assert(blockMat.toBreeze() === idxRowMat.toBreeze()) + + intercept[IllegalArgumentException] { + idxRowMat.toBlockMatrix(-1, 2) + } + intercept[IllegalArgumentException] { + idxRowMat.toBlockMatrix(2, 0) + } + } + test("multiply a local matrix") { val A = new IndexedRowMatrix(indexedRows) val B = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 681ce9263933b..6d6c0aa5be812 100644 --- 
a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -46,22 +46,4 @@ class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) assert(sliding === expected) } - - test("treeAggregate") { - val rdd = sc.makeRDD(-1000 until 1000, 10) - def seqOp = (c: Long, x: Int) => c + x - def combOp = (c1: Long, c2: Long) => c1 + c2 - for (depth <- 1 until 10) { - val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) - assert(sum === -1000L) - } - } - - test("treeReduce") { - val rdd = sc.makeRDD(-1000 until 1000, 10) - for (depth <- 1 until 10) { - val sum = rdd.treeReduce(_ + _, depth) - assert(sum === -1000) - } - } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 3aa97e544680b..e8341a5d0d104 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -128,6 +128,11 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { } } + test("SPARK-5496: BoostingStrategy.defaultParams should recognize Classification") { + for (algo <- Seq("classification", "Classification", "regression", "Regression")) { + BoostingStrategy.defaultParams(algo) + } + } } object GradientBoostedTreesSuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala new file mode 100644 index 0000000000000..92b498580af03 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. + */ +class ImpuritySuite extends FunSuite with MLlibTestSparkContext { + test("Gini impurity does not support negative labels") { + val gini = new GiniAggregator(2) + intercept[IllegalArgumentException] { + gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0) + } + } + + test("Entropy does not support negative labels") { + val entropy = new EntropyAggregator(2) + intercept[IllegalArgumentException] { + entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index f7f0f20c6c125..55e963977b54f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -196,6 +196,22 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { featureSubsetStrategy = "sqrt", seed = 12345) EnsembleTestHelper.validateClassifier(model, arr, 1.0) 
} + + test("subsampling rate in RandomForest"){ + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClasses = 2, categoricalFeaturesInfo = Map.empty[Int, Int], + useNodeIdCache = true) + + val rf1 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3, + featureSubsetStrategy = "auto", seed = 123) + strategy.subsamplingRate = 0.5 + val rf2 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3, + featureSubsetStrategy = "auto", seed = 123) + assert(rf1.toDebugString != rf2.toDebugString) + } + } diff --git a/network/common/pom.xml b/network/common/pom.xml index 245a96b8c4038..5a9bbe105d9f1 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -48,10 +48,15 @@ slf4j-api provided + com.google.guava guava - provided + compile @@ -87,11 +92,6 @@ maven-jar-plugin 2.2 - - - test-jar - - test-jar-on-test-compile test-compile @@ -101,6 +101,18 @@ + + org.apache.maven.plugins + maven-shade-plugin + + false + + + com.google.guava:guava + + + + diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 5bfa1ac9c373e..c2d0300ecd904 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -52,7 +52,6 @@ com.google.guava guava - provided diff --git a/pom.xml b/pom.xml index b993391b15042..4adfdf3eb8702 100644 --- a/pom.xml +++ b/pom.xml @@ -117,7 +117,7 @@ 2.0.1 0.21.0 shaded-protobuf - 1.7.5 + 1.7.10 1.2.17 1.0.4 2.4.1 @@ -1264,7 +1264,10 @@ - + org.apache.maven.plugins maven-shade-plugin @@ -1276,6 +1279,23 @@ org.spark-project.spark:unused + + + com.google.common + org.spark-project.guava + + + com/google/common/base/Absent* + com/google/common/base/Function + com/google/common/base/Optional* + com/google/common/base/Present* + com/google/common/base/Supplier + + + diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index bc5d81f12d746..14ba03ed4634b 100644 --- 
a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -52,6 +52,20 @@ object MimaExcludes { "org.apache.spark.mllib.linalg.Matrices.randn"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.linalg.Matrices.rand") + ) ++ Seq( + // SPARK-5321 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.transpose"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." + + "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.isTransposed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.foreachActive") ) ++ Seq( // SPARK-3325 ProblemFilters.exclude[MissingMethodProblem]( @@ -81,11 +95,30 @@ object MimaExcludes { ) ++ Seq( // SPARK-5166 Spark SQL API stabilization ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"), + 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate") ) ++ Seq( // SPARK-5270 ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaRDDLike.isEmpty") + ) ++ Seq( + // SPARK-5430 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.treeReduce"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.treeAggregate") ) ++ Seq( // SPARK-5297 Java FileStream do not work with custom key/values ProblemFilters.exclude[MissingMethodProblem]( diff --git a/project/build.properties b/project/build.properties index 32a3aeefaf9fb..064ec843da9ea 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.6 +sbt.version=0.13.7 diff --git a/python/docs/conf.py b/python/docs/conf.py index e58d97ae6a746..b00dce95d65b4 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -55,9 +55,9 @@ # built documents. # # The short X.Y version. -version = '1.2-SNAPSHOT' +version = '1.3-SNAPSHOT' # The full version, including alpha/beta/rc tags. 
-release = '1.2-SNAPSHOT' +release = '1.3-SNAPSHOT' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/python/docs/index.rst b/python/docs/index.rst index 703bef644de28..d150de9d5c502 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -14,6 +14,7 @@ Contents: pyspark pyspark.sql pyspark.streaming + pyspark.ml pyspark.mllib diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst new file mode 100644 index 0000000000000..f10d1339a9a8f --- /dev/null +++ b/python/docs/pyspark.ml.rst @@ -0,0 +1,29 @@ +pyspark.ml package +===================== + +Submodules +---------- + +pyspark.ml module +----------------- + +.. automodule:: pyspark.ml + :members: + :undoc-members: + :inherited-members: + +pyspark.ml.feature module +------------------------- + +.. automodule:: pyspark.ml.feature + :members: + :undoc-members: + :inherited-members: + +pyspark.ml.classification module +-------------------------------- + +.. 
automodule:: pyspark.ml.classification + :members: + :undoc-members: + :inherited-members: diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst index e81be3b6cb796..0df12c49ad033 100644 --- a/python/docs/pyspark.rst +++ b/python/docs/pyspark.rst @@ -9,6 +9,7 @@ Subpackages pyspark.sql pyspark.streaming + pyspark.ml pyspark.mllib Contents diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 9556e4718e585..d3efcdf221d82 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -45,6 +45,7 @@ from pyspark.accumulators import Accumulator, AccumulatorParam from pyspark.broadcast import Broadcast from pyspark.serializers import MarshalSerializer, PickleSerializer +from pyspark.profiler import Profiler, BasicProfiler # for back compatibility from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row @@ -52,4 +53,5 @@ __all__ = [ "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer", + "Profiler", "BasicProfiler", ] diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index b8cdbbe3cf2b6..ccbca67656c8d 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -215,21 +215,6 @@ def addInPlace(self, value1, value2): COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) -class PStatsParam(AccumulatorParam): - """PStatsParam is used to merge pstats.Stats""" - - @staticmethod - def zero(value): - return None - - @staticmethod - def addInPlace(value1, value2): - if value1 is None: - return value2 - value1.add(value2) - return value1 - - class _UpdateRequestHandler(SocketServer.StreamRequestHandler): """ diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 64f6a3ca6bf4c..c0dec16ac1b25 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -20,7 +20,6 @@ import sys from threading import Lock from tempfile import NamedTemporaryFile -import 
atexit from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -33,6 +32,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call +from pyspark.profiler import ProfilerCollector, BasicProfiler from py4j.java_collections import ListConverter @@ -66,7 +66,7 @@ class SparkContext(object): def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, - gateway=None, jsc=None): + gateway=None, jsc=None, profiler_cls=BasicProfiler): """ Create a new SparkContext. At least the master and app name should be set, either through the named parameters here or through C{conf}. @@ -88,6 +88,9 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, :param conf: A L{SparkConf} object setting Spark properties. :param gateway: Use an existing gateway and JVM, otherwise a new JVM will be instantiated. + :param jsc: The JavaSparkContext instance (optional). + :param profiler_cls: A class of custom Profiler used to do profiling + (default is pyspark.profiler.BasicProfiler).
>>> from pyspark.context import SparkContext @@ -102,14 +105,14 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, SparkContext._ensure_initialized(self, gateway=gateway) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf, jsc) + conf, jsc, profiler_cls) except: # If an error occurs, clean up in order to allow future SparkContext creation: self.stop() raise def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf, jsc): + conf, jsc, profiler_cls): self.environment = environment or {} self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size @@ -192,7 +195,11 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath() # profiling stats collected for each PythonRDD - self._profile_stats = [] + if self._conf.get("spark.python.profile", "false") == "true": + dump_path = self._conf.get("spark.python.profile.dump", None) + self.profiler_collector = ProfilerCollector(profiler_cls, dump_path) + else: + self.profiler_collector = None def _initialize_context(self, jconf): """ @@ -229,6 +236,14 @@ def _ensure_initialized(cls, instance=None, gateway=None): else: SparkContext._active_spark_context = instance + def __getnewargs__(self): + # This method is called when attempting to pickle SparkContext, which is always an error: + raise Exception( + "It appears that you are attempting to reference SparkContext from a broadcast " + "variable, action, or transformation. SparkContext can only be used on the driver, " + "not in code that is run on workers. For more information, see SPARK-5063." + ) + def __enter__(self): """ Enable 'with SparkContext(...) as sc: app(sc)' syntax.
@@ -818,39 +833,18 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) return list(mappedRDD._collect_iterator_through_file(it)) - def _add_profile(self, id, profileAcc): - if not self._profile_stats: - dump_path = self._conf.get("spark.python.profile.dump") - if dump_path: - atexit.register(self.dump_profiles, dump_path) - else: - atexit.register(self.show_profiles) - - self._profile_stats.append([id, profileAcc, False]) - def show_profiles(self): """ Print the profile stats to stdout """ - for i, (id, acc, showed) in enumerate(self._profile_stats): - stats = acc.value - if not showed and stats: - print "=" * 60 - print "Profile of RDD<id=%d>" % id - print "=" * 60 - stats.sort_stats("time", "cumulative").print_stats() - # mark it as showed - self._profile_stats[i][2] = True + # profiler_collector is None unless spark.python.profile is "true"; stay a no-op then. + if self.profiler_collector is not None: + self.profiler_collector.show_profiles() def dump_profiles(self, path): """ Dump the profile stats into directory `path` """ - if not os.path.exists(path): - os.makedirs(path) - for id, acc, _ in self._profile_stats: - stats = acc.value - if stats: - p = os.path.join(path, "rdd_%d.pstats" % id) - stats.dump_stats(p) - self._profile_stats = [] + # profiler_collector is None unless spark.python.profile is "true"; nothing to dump then. + if self.profiler_collector is not None: + self.profiler_collector.dump_profiles(path) def _test(): diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index a975dc19cb78e..a0a028446d5fd 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -111,10 +111,9 @@ def run(self): java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") - java_import(gateway.jvm, "org.apache.spark.sql.SQLContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext") + #
TODO(davies): move into sql + java_import(gateway.jvm, "org.apache.spark.sql.*") + java_import(gateway.jvm, "org.apache.spark.sql.hive.*") java_import(gateway.jvm, "scala.Tuple2") return gateway diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py new file mode 100644 index 0000000000000..47fed80f42e13 --- /dev/null +++ b/python/pyspark/ml/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.param import * +from pyspark.ml.pipeline import * + +__all__ = ["Param", "Params", "Transformer", "Estimator", "Pipeline"] diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py new file mode 100644 index 0000000000000..6bd2aa8e47837 --- /dev/null +++ b/python/pyspark/ml/classification.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.util import inherit_doc +from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\ + HasRegParam + + +__all__ = ['LogisticRegression', 'LogisticRegressionModel'] + + +@inherit_doc +class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, + HasRegParam): + """ + Logistic regression. + + >>> from pyspark.sql import Row + >>> from pyspark.mllib.linalg import Vectors + >>> dataset = sqlCtx.inferSchema(sc.parallelize([ \ + Row(label=1.0, features=Vectors.dense(1.0)), \ + Row(label=0.0, features=Vectors.sparse(1, [], []))])) + >>> lr = LogisticRegression() \ + .setMaxIter(5) \ + .setRegParam(0.01) + >>> model = lr.fit(dataset) + >>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))])) + >>> print model.transform(test0).head().prediction + 0.0 + >>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))])) + >>> print model.transform(test1).head().prediction + 1.0 + """ + _java_class = "org.apache.spark.ml.classification.LogisticRegression" + + def _create_model(self, java_model): + return LogisticRegressionModel(java_model) + + +class LogisticRegressionModel(JavaModel): + """ + Model fitted by LogisticRegression. 
+ """ + + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.feature tests") + sqlCtx = SQLContext(sc) + globs['sc'] = sc + globs['sqlCtx'] = sqlCtx + (failure_count, test_count) = doctest.testmod( + globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py new file mode 100644 index 0000000000000..e088acd0ca82d --- /dev/null +++ b/python/pyspark/ml/feature.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures +from pyspark.ml.util import inherit_doc +from pyspark.ml.wrapper import JavaTransformer + +__all__ = ['Tokenizer', 'HashingTF'] + + +@inherit_doc +class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): + """ + A tokenizer that converts the input string to lowercase and then + splits it by white spaces. 
+ + >>> from pyspark.sql import Row + >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(text="a b c")])) + >>> tokenizer = Tokenizer() \ + .setInputCol("text") \ + .setOutputCol("words") + >>> print tokenizer.transform(dataset).head() + Row(text=u'a b c', words=[u'a', u'b', u'c']) + >>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).head() + Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + """ + + _java_class = "org.apache.spark.ml.feature.Tokenizer" + + +@inherit_doc +class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): + """ + Maps a sequence of terms to their term frequencies using the + hashing trick. + + >>> from pyspark.sql import Row + >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(words=["a", "b", "c"])])) + >>> hashingTF = HashingTF() \ + .setNumFeatures(10) \ + .setInputCol("words") \ + .setOutputCol("features") + >>> print hashingTF.transform(dataset).head().features + (10,[7,8,9],[1.0,1.0,1.0]) + >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"} + >>> print hashingTF.transform(dataset, params).head().vector + (5,[2,3,4],[1.0,1.0,1.0]) + """ + + _java_class = "org.apache.spark.ml.feature.HashingTF" + + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.feature tests") + sqlCtx = SQLContext(sc) + globs['sc'] = sc + globs['sqlCtx'] = sqlCtx + (failure_count, test_count) = doctest.testmod( + globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py new file mode 100644 index 0000000000000..5566792cead48 --- /dev/null +++ b/python/pyspark/ml/param/__init__.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) 
under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABCMeta + +from pyspark.ml.util import Identifiable + + +__all__ = ['Param', 'Params'] + + +class Param(object): + """ + A param with self-contained documentation and optionally default value. + """ + + def __init__(self, parent, name, doc, defaultValue=None): + if not isinstance(parent, Identifiable): + raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__) + self.parent = parent + self.name = str(name) + self.doc = str(doc) + self.defaultValue = defaultValue + + def __str__(self): + return str(self.parent) + "-" + self.name + + def __repr__(self): + return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \ + (self.parent, self.name, self.doc, self.defaultValue) + + +class Params(Identifiable): + """ + Components that take parameters. This also provides an internal + param map to store parameter values attached to the instance. + """ + + __metaclass__ = ABCMeta + + def __init__(self): + super(Params, self).__init__() + #: embedded param map + self.paramMap = {} + + @property + def params(self): + """ + Returns all params. The default implementation uses + :py:func:`dir` to get all attributes of type + :py:class:`Param`. 
+ """ + return filter(lambda attr: isinstance(attr, Param), + [getattr(self, x) for x in dir(self) if x != "params"]) + + def _merge_params(self, params): + paramMap = self.paramMap.copy() + paramMap.update(params) + return paramMap + + @staticmethod + def _dummy(): + """ + Returns a dummy Params instance used as a placeholder to generate docs. + """ + dummy = Params() + dummy.uid = "undefined" + return dummy diff --git a/python/pyspark/ml/param/_gen_shared_params.py b/python/pyspark/ml/param/_gen_shared_params.py new file mode 100644 index 0000000000000..5eb81106f116c --- /dev/null +++ b/python/pyspark/ml/param/_gen_shared_params.py @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +header = """# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#""" + + +def _gen_param_code(name, doc, defaultValue): + """ + Generates Python code for a shared param class. + + :param name: param name + :param doc: param doc + :param defaultValue: string representation of the param + :return: code string + """ + # TODO: How to correctly inherit instance attributes? + template = '''class Has$Name(Params): + """ + Params with $name. + """ + + # a placeholder to make it appear in the generated doc + $name = Param(Params._dummy(), "$name", "$doc", $defaultValue) + + def __init__(self): + super(Has$Name, self).__init__() + #: param for $doc + self.$name = Param(self, "$name", "$doc", $defaultValue) + + def set$Name(self, value): + """ + Sets the value of :py:attr:`$name`. + """ + self.paramMap[self.$name] = value + return self + + def get$Name(self): + """ + Gets the value of $name or its default value. + """ + if self.$name in self.paramMap: + return self.paramMap[self.$name] + else: + return self.$name.defaultValue''' + + upperCamelName = name[0].upper() + name[1:] + return template \ + .replace("$name", name) \ + .replace("$Name", upperCamelName) \ + .replace("$doc", doc) \ + .replace("$defaultValue", defaultValue) + +if __name__ == "__main__": + print header + print "\n# DO NOT MODIFY. 
The code is generated by _gen_shared_params.py.\n" + print "from pyspark.ml.param import Param, Params\n\n" + shared = [ + ("maxIter", "max number of iterations", "100"), + ("regParam", "regularization constant", "0.1"), + ("featuresCol", "features column name", "'features'"), + ("labelCol", "label column name", "'label'"), + ("predictionCol", "prediction column name", "'prediction'"), + ("inputCol", "input column name", "'input'"), + ("outputCol", "output column name", "'output'"), + ("numFeatures", "number of features", "1 << 18")] + code = [] + for name, doc, defaultValue in shared: + code.append(_gen_param_code(name, doc, defaultValue)) + print "\n\n\n".join(code) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py new file mode 100644 index 0000000000000..586822f2de423 --- /dev/null +++ b/python/pyspark/ml/param/shared.py @@ -0,0 +1,260 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# DO NOT MODIFY. The code is generated by _gen_shared_params.py. + +from pyspark.ml.param import Param, Params + + +class HasMaxIter(Params): + """ + Params with maxIter. 
+ """ + + # a placeholder to make it appear in the generated doc + maxIter = Param(Params._dummy(), "maxIter", "max number of iterations", 100) + + def __init__(self): + super(HasMaxIter, self).__init__() + #: param for max number of iterations + self.maxIter = Param(self, "maxIter", "max number of iterations", 100) + + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + self.paramMap[self.maxIter] = value + return self + + def getMaxIter(self): + """ + Gets the value of maxIter or its default value. + """ + if self.maxIter in self.paramMap: + return self.paramMap[self.maxIter] + else: + return self.maxIter.defaultValue + + +class HasRegParam(Params): + """ + Params with regParam. + """ + + # a placeholder to make it appear in the generated doc + regParam = Param(Params._dummy(), "regParam", "regularization constant", 0.1) + + def __init__(self): + super(HasRegParam, self).__init__() + #: param for regularization constant + self.regParam = Param(self, "regParam", "regularization constant", 0.1) + + def setRegParam(self, value): + """ + Sets the value of :py:attr:`regParam`. + """ + self.paramMap[self.regParam] = value + return self + + def getRegParam(self): + """ + Gets the value of regParam or its default value. + """ + if self.regParam in self.paramMap: + return self.paramMap[self.regParam] + else: + return self.regParam.defaultValue + + +class HasFeaturesCol(Params): + """ + Params with featuresCol. + """ + + # a placeholder to make it appear in the generated doc + featuresCol = Param(Params._dummy(), "featuresCol", "features column name", 'features') + + def __init__(self): + super(HasFeaturesCol, self).__init__() + #: param for features column name + self.featuresCol = Param(self, "featuresCol", "features column name", 'features') + + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. 
+ """ + self.paramMap[self.featuresCol] = value + return self + + def getFeaturesCol(self): + """ + Gets the value of featuresCol or its default value. + """ + if self.featuresCol in self.paramMap: + return self.paramMap[self.featuresCol] + else: + return self.featuresCol.defaultValue + + +class HasLabelCol(Params): + """ + Params with labelCol. + """ + + # a placeholder to make it appear in the generated doc + labelCol = Param(Params._dummy(), "labelCol", "label column name", 'label') + + def __init__(self): + super(HasLabelCol, self).__init__() + #: param for label column name + self.labelCol = Param(self, "labelCol", "label column name", 'label') + + def setLabelCol(self, value): + """ + Sets the value of :py:attr:`labelCol`. + """ + self.paramMap[self.labelCol] = value + return self + + def getLabelCol(self): + """ + Gets the value of labelCol or its default value. + """ + if self.labelCol in self.paramMap: + return self.paramMap[self.labelCol] + else: + return self.labelCol.defaultValue + + +class HasPredictionCol(Params): + """ + Params with predictionCol. + """ + + # a placeholder to make it appear in the generated doc + predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name", 'prediction') + + def __init__(self): + super(HasPredictionCol, self).__init__() + #: param for prediction column name + self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction') + + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + self.paramMap[self.predictionCol] = value + return self + + def getPredictionCol(self): + """ + Gets the value of predictionCol or its default value. + """ + if self.predictionCol in self.paramMap: + return self.paramMap[self.predictionCol] + else: + return self.predictionCol.defaultValue + + +class HasInputCol(Params): + """ + Params with inputCol. 
+ """ + + # a placeholder to make it appear in the generated doc + inputCol = Param(Params._dummy(), "inputCol", "input column name", 'input') + + def __init__(self): + super(HasInputCol, self).__init__() + #: param for input column name + self.inputCol = Param(self, "inputCol", "input column name", 'input') + + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + self.paramMap[self.inputCol] = value + return self + + def getInputCol(self): + """ + Gets the value of inputCol or its default value. + """ + if self.inputCol in self.paramMap: + return self.paramMap[self.inputCol] + else: + return self.inputCol.defaultValue + + +class HasOutputCol(Params): + """ + Params with outputCol. + """ + + # a placeholder to make it appear in the generated doc + outputCol = Param(Params._dummy(), "outputCol", "output column name", 'output') + + def __init__(self): + super(HasOutputCol, self).__init__() + #: param for output column name + self.outputCol = Param(self, "outputCol", "output column name", 'output') + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + self.paramMap[self.outputCol] = value + return self + + def getOutputCol(self): + """ + Gets the value of outputCol or its default value. + """ + if self.outputCol in self.paramMap: + return self.paramMap[self.outputCol] + else: + return self.outputCol.defaultValue + + +class HasNumFeatures(Params): + """ + Params with numFeatures. + """ + + # a placeholder to make it appear in the generated doc + numFeatures = Param(Params._dummy(), "numFeatures", "number of features", 1 << 18) + + def __init__(self): + super(HasNumFeatures, self).__init__() + #: param for number of features + self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18) + + def setNumFeatures(self, value): + """ + Sets the value of :py:attr:`numFeatures`. 
+ """ + self.paramMap[self.numFeatures] = value + return self + + def getNumFeatures(self): + """ + Gets the value of numFeatures or its default value. + """ + if self.numFeatures in self.paramMap: + return self.paramMap[self.numFeatures] + else: + return self.numFeatures.defaultValue diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py new file mode 100644 index 0000000000000..2d239f8c802a0 --- /dev/null +++ b/python/pyspark/ml/pipeline.py @@ -0,0 +1,154 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABCMeta, abstractmethod + +from pyspark.ml.param import Param, Params +from pyspark.ml.util import inherit_doc + + +__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel'] + + +@inherit_doc +class Estimator(Params): + """ + Abstract class for estimators that fit models to data. + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def fit(self, dataset, params={}): + """ + Fits a model to the input dataset with optional parameters. 
+ + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.DataFrame` + :param params: an optional param map that overwrites embedded + params + :returns: fitted model + """ + raise NotImplementedError() + + +@inherit_doc +class Transformer(Params): + """ + Abstract class for transformers that transform one dataset into + another. + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def transform(self, dataset, params={}): + """ + Transforms the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.DataFrame` + :param params: an optional param map that overwrites embedded + params + :returns: transformed dataset + """ + raise NotImplementedError() + + +@inherit_doc +class Pipeline(Estimator): + """ + A simple pipeline, which acts as an estimator. A Pipeline consists + of a sequence of stages, each of which is either an + :py:class:`Estimator` or a :py:class:`Transformer`. When + :py:meth:`Pipeline.fit` is called, the stages are executed in + order. If a stage is an :py:class:`Estimator`, its + :py:meth:`Estimator.fit` method will be called on the input + dataset to fit a model. Then the model, which is a transformer, + will be used to transform the dataset as the input to the next + stage. If a stage is a :py:class:`Transformer`, its + :py:meth:`Transformer.transform` method will be called to produce + the dataset for the next stage. The fitted model from a + :py:class:`Pipeline` is a :py:class:`PipelineModel`, which + consists of fitted models and transformers, corresponding to the + pipeline stages. If there are no stages, the pipeline acts as an + identity transformer. + """ + + def __init__(self): + super(Pipeline, self).__init__() + #: Param for pipeline stages. + self.stages = Param(self, "stages", "pipeline stages") + + def setStages(self, value): + """ + Set pipeline stages.
+ :param value: a list of transformers or estimators + :return: the pipeline instance + """ + self.paramMap[self.stages] = value + return self + + def getStages(self): + """ + Get pipeline stages. + """ + if self.stages in self.paramMap: + return self.paramMap[self.stages] + + def fit(self, dataset, params={}): + paramMap = self._merge_params(params) + stages = paramMap[self.stages] + for stage in stages: + if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): + raise ValueError( + "Cannot recognize a pipeline stage of type %s." % type(stage).__name__) + indexOfLastEstimator = -1 + for i, stage in enumerate(stages): + if isinstance(stage, Estimator): + indexOfLastEstimator = i + transformers = [] + for i, stage in enumerate(stages): + if i <= indexOfLastEstimator: + if isinstance(stage, Transformer): + transformers.append(stage) + dataset = stage.transform(dataset, paramMap) + else: # must be an Estimator + model = stage.fit(dataset, paramMap) + transformers.append(model) + if i < indexOfLastEstimator: + dataset = model.transform(dataset, paramMap) + else: + transformers.append(stage) + return PipelineModel(transformers) + + +@inherit_doc +class PipelineModel(Transformer): + """ + Represents a compiled pipeline with transformers and fitted models. + """ + + def __init__(self, transformers): + super(PipelineModel, self).__init__() + self.transformers = transformers + + def transform(self, dataset, params={}): + paramMap = self._merge_params(params) + for t in self.transformers: + dataset = t.transform(dataset, paramMap) + return dataset diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py new file mode 100644 index 0000000000000..b627c2b4e930b --- /dev/null +++ b/python/pyspark/ml/tests.py @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for Spark ML Python APIs. +""" + +import sys + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase +from pyspark.sql import DataFrame +from pyspark.ml.param import Param +from pyspark.ml.pipeline import Transformer, Estimator, Pipeline + + +class MockDataset(DataFrame): + + def __init__(self): + self.index = 0 + + +class MockTransformer(Transformer): + + def __init__(self): + super(MockTransformer, self).__init__() + self.fake = Param(self, "fake", "fake", None) + self.dataset_index = None + self.fake_param_value = None + + def transform(self, dataset, params={}): + self.dataset_index = dataset.index + if self.fake in params: + self.fake_param_value = params[self.fake] + dataset.index += 1 + return dataset + + +class MockEstimator(Estimator): + + def __init__(self): + super(MockEstimator, self).__init__() + self.fake = Param(self, "fake", "fake", None) + self.dataset_index = None + self.fake_param_value = None + self.model = None + + def fit(self, dataset, params={}): + self.dataset_index = dataset.index + if self.fake in params: + self.fake_param_value = params[self.fake] + model = MockModel() + self.model = model + return model + + +class 
MockModel(MockTransformer, Transformer): + + def __init__(self): + super(MockModel, self).__init__() + + +class PipelineTests(PySparkTestCase): + + def test_pipeline(self): + dataset = MockDataset() + estimator0 = MockEstimator() + transformer1 = MockTransformer() + estimator2 = MockEstimator() + transformer3 = MockTransformer() + pipeline = Pipeline() \ + .setStages([estimator0, transformer1, estimator2, transformer3]) + pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1}) + self.assertEqual(0, estimator0.dataset_index) + self.assertEqual(0, estimator0.fake_param_value) + model0 = estimator0.model + self.assertEqual(0, model0.dataset_index) + self.assertEqual(1, transformer1.dataset_index) + self.assertEqual(1, transformer1.fake_param_value) + self.assertEqual(2, estimator2.dataset_index) + model2 = estimator2.model + self.assertIsNone(model2.dataset_index, "The model produced by the last estimator should " + "not be called during fit.") + dataset = pipeline_model.transform(dataset) + self.assertEqual(2, model0.dataset_index) + self.assertEqual(3, transformer1.dataset_index) + self.assertEqual(4, model2.dataset_index) + self.assertEqual(5, transformer3.dataset_index) + self.assertEqual(6, dataset.index) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py new file mode 100644 index 0000000000000..b1caa84b6306a --- /dev/null +++ b/python/pyspark/ml/util.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import uuid + + +def inherit_doc(cls): + for name, func in vars(cls).items(): + # only inherit docstring for public functions + if name.startswith("_"): + continue + if not func.__doc__: + for parent in cls.__bases__: + parent_func = getattr(parent, name, None) + if parent_func and getattr(parent_func, "__doc__", None): + func.__doc__ = parent_func.__doc__ + break + return cls + + +class Identifiable(object): + """ + Object with a unique ID. + """ + + def __init__(self): + #: A unique id for the object. The default implementation + #: concatenates the class name, "-", and 8 random hex chars. + self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8] + + def __repr__(self): + return self.uid diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py new file mode 100644 index 0000000000000..9e12ddc3d9b8f --- /dev/null +++ b/python/pyspark/ml/wrapper.py @@ -0,0 +1,149 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABCMeta + +from pyspark import SparkContext +from pyspark.sql import DataFrame +from pyspark.ml.param import Params +from pyspark.ml.pipeline import Estimator, Transformer +from pyspark.ml.util import inherit_doc + + +def _jvm(): + """ + Returns the JVM view associated with SparkContext. Must be called + after SparkContext is initialized. + """ + jvm = SparkContext._jvm + if jvm: + return jvm + else: + raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") + + +@inherit_doc +class JavaWrapper(Params): + """ + Utility class to help create wrapper classes from Java/Scala + implementations of pipeline components. + """ + + __metaclass__ = ABCMeta + + #: Fully-qualified class name of the wrapped Java component. + _java_class = None + + def _java_obj(self): + """ + Returns or creates a Java object. + """ + java_obj = _jvm() + for name in self._java_class.split("."): + java_obj = getattr(java_obj, name) + return java_obj() + + def _transfer_params_to_java(self, params, java_obj): + """ + Transforms the embedded params and additional params to the + input Java object. + :param params: additional params (overwriting embedded values) + :param java_obj: Java object to receive the params + """ + paramMap = self._merge_params(params) + for param in self.params: + if param in paramMap: + java_obj.set(param.name, paramMap[param]) + + def _empty_java_param_map(self): + """ + Returns an empty Java ParamMap reference. 
+ """ + return _jvm().org.apache.spark.ml.param.ParamMap() + + def _create_java_param_map(self, params, java_obj): + paramMap = self._empty_java_param_map() + for param, value in params.items(): + if param.parent is self: + paramMap.put(java_obj.getParam(param.name), value) + return paramMap + + +@inherit_doc +class JavaEstimator(Estimator, JavaWrapper): + """ + Base class for :py:class:`Estimator`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def _create_model(self, java_model): + """ + Creates a model from the input Java model reference. + """ + return JavaModel(java_model) + + def _fit_java(self, dataset, params={}): + """ + Fits a Java model to the input dataset. + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.SchemaRDD` + :param params: additional params (overwriting embedded values) + :return: fitted Java model + """ + java_obj = self._java_obj() + self._transfer_params_to_java(params, java_obj) + return java_obj.fit(dataset._jdf, self._empty_java_param_map()) + + def fit(self, dataset, params={}): + java_model = self._fit_java(dataset, params) + return self._create_model(java_model) + + +@inherit_doc +class JavaTransformer(Transformer, JavaWrapper): + """ + Base class for :py:class:`Transformer`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def transform(self, dataset, params={}): + java_obj = self._java_obj() + self._transfer_params_to_java({}, java_obj) + java_param_map = self._create_java_param_map(params, java_obj) + return DataFrame(java_obj.transform(dataset._jdf, java_param_map), + dataset.sql_ctx) + + +@inherit_doc +class JavaModel(JavaTransformer): + """ + Base class for :py:class:`Model`s that wrap Java/Scala + implementations. 
+ """ + + __metaclass__ = ABCMeta + + def __init__(self, java_model): + super(JavaTransformer, self).__init__() + self._java_model = java_model + + def _java_obj(self): + return self._java_model diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py new file mode 100644 index 0000000000000..799d260c096b1 --- /dev/null +++ b/python/pyspark/mllib/stat/__init__.py @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for statistical functions in MLlib. +""" + +from pyspark.mllib.stat._statistics import * + +__all__ = ["Statistics", "MultivariateStatisticalSummary"] diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat/_statistics.py similarity index 88% rename from python/pyspark/mllib/stat.py rename to python/pyspark/mllib/stat/_statistics.py index c8af777a8b00d..218ac148ca992 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -15,17 +15,14 @@ # limitations under the License. # -""" -Python package for statistical functions in MLlib. 
-""" - from pyspark import RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import Matrix, _convert_to_vector from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.stat.test import ChiSqTestResult -__all__ = ['MultivariateStatisticalSummary', 'ChiSqTestResult', 'Statistics'] +__all__ = ['MultivariateStatisticalSummary', 'Statistics'] class MultivariateStatisticalSummary(JavaModelWrapper): @@ -53,54 +50,6 @@ def min(self): return self.call("min").toArray() -class ChiSqTestResult(JavaModelWrapper): - """ - .. note:: Experimental - - Object containing the test results for the chi-squared hypothesis test. - """ - @property - def method(self): - """ - Name of the test method - """ - return self._java_model.method() - - @property - def pValue(self): - """ - The probability of obtaining a test statistic result at least as - extreme as the one that was actually observed, assuming that the - null hypothesis is true. - """ - return self._java_model.pValue() - - @property - def degreesOfFreedom(self): - """ - Returns the degree(s) of freedom of the hypothesis test. - Return type should be Number(e.g. Int, Double) or tuples of Numbers. - """ - return self._java_model.degreesOfFreedom() - - @property - def statistic(self): - """ - Test statistic. - """ - return self._java_model.statistic() - - @property - def nullHypothesis(self): - """ - Null hypothesis of the test. - """ - return self._java_model.nullHypothesis() - - def __str__(self): - return self._java_model.toString() - - class Statistics(object): @staticmethod diff --git a/python/pyspark/mllib/stat/test.py b/python/pyspark/mllib/stat/test.py new file mode 100644 index 0000000000000..762506e952b43 --- /dev/null +++ b/python/pyspark/mllib/stat/test.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.mllib.common import JavaModelWrapper + + +__all__ = ["ChiSqTestResult"] + + +class ChiSqTestResult(JavaModelWrapper): + """ + .. note:: Experimental + + Object containing the test results for the chi-squared hypothesis test. + """ + @property + def method(self): + """ + Name of the test method + """ + return self._java_model.method() + + @property + def pValue(self): + """ + The probability of obtaining a test statistic result at least as + extreme as the one that was actually observed, assuming that the + null hypothesis is true. + """ + return self._java_model.pValue() + + @property + def degreesOfFreedom(self): + """ + Returns the degree(s) of freedom of the hypothesis test. + Return type should be Number(e.g. Int, Double) or tuples of Numbers. + """ + return self._java_model.degreesOfFreedom() + + @property + def statistic(self): + """ + Test statistic. + """ + return self._java_model.statistic() + + @property + def nullHypothesis(self): + """ + Null hypothesis of the test. 
+ """ + return self._java_model.nullHypothesis() + + def __str__(self): + return self._java_model.toString() diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f48e3d6dacb4b..61e0cf5d90bd0 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -169,7 +169,7 @@ def test_kmeans_deterministic(self): def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes - from pyspark.mllib.tree import DecisionTree + from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees data = [ LabeledPoint(0.0, [1, 0, 0]), LabeledPoint(1.0, [0, 1, 1]), @@ -198,18 +198,31 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[3]) > 0) categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories - dt_model = \ - DecisionTree.trainClassifier(rdd, numClasses=2, - categoricalFeaturesInfo=categoricalFeaturesInfo) + dt_model = DecisionTree.trainClassifier( + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) self.assertTrue(dt_model.predict(features[2]) <= 0) self.assertTrue(dt_model.predict(features[3]) > 0) + rf_model = RandomForest.trainClassifier( + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100) + self.assertTrue(rf_model.predict(features[0]) <= 0) + self.assertTrue(rf_model.predict(features[1]) > 0) + self.assertTrue(rf_model.predict(features[2]) <= 0) + self.assertTrue(rf_model.predict(features[3]) > 0) + + gbt_model = GradientBoostedTrees.trainClassifier( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(gbt_model.predict(features[0]) <= 0) + self.assertTrue(gbt_model.predict(features[1]) > 0) + self.assertTrue(gbt_model.predict(features[2]) <= 0) + self.assertTrue(gbt_model.predict(features[3]) > 0) + def test_regression(self): from 
pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ RidgeRegressionWithSGD - from pyspark.mllib.tree import DecisionTree + from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees data = [ LabeledPoint(-1.0, [0, -1]), LabeledPoint(1.0, [0, 1]), @@ -238,13 +251,27 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[3]) > 0) categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories - dt_model = \ - DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + dt_model = DecisionTree.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) self.assertTrue(dt_model.predict(features[2]) <= 0) self.assertTrue(dt_model.predict(features[3]) > 0) + rf_model = RandomForest.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100) + self.assertTrue(rf_model.predict(features[0]) <= 0) + self.assertTrue(rf_model.predict(features[1]) > 0) + self.assertTrue(rf_model.predict(features[2]) <= 0) + self.assertTrue(rf_model.predict(features[3]) > 0) + + gbt_model = GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(gbt_model.predict(features[0]) <= 0) + self.assertTrue(gbt_model.predict(features[1]) > 0) + self.assertTrue(gbt_model.predict(features[2]) <= 0) + self.assertTrue(gbt_model.predict(features[3]) > 0) + class StatTests(PySparkTestCase): # SPARK-4023 diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 66702478474dc..aae48f213246b 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -24,16 +24,48 @@ from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint -__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel', 'RandomForest'] +__all__ = ['DecisionTreeModel', 'DecisionTree', 
'RandomForestModel', + 'RandomForest', 'GradientBoostedTrees'] -class DecisionTreeModel(JavaModelWrapper): +class TreeEnsembleModel(JavaModelWrapper): + def predict(self, x): + """ + Predict values for a single data point or an RDD of points using + the model trained. + """ + if isinstance(x, RDD): + return self.call("predict", x.map(_convert_to_vector)) + + else: + return self.call("predict", _convert_to_vector(x)) + + def numTrees(self): + """ + Get number of trees in ensemble. + """ + return self.call("numTrees") + + def totalNumNodes(self): + """ + Get total number of nodes, summed over all trees in the ensemble. + """ + return self.call("totalNumNodes") + + def __repr__(self): + """ Summary of model """ + return self._java_model.toString() + + def toDebugString(self): + """ Full model """ + return self._java_model.toDebugString() + +class DecisionTreeModel(JavaModelWrapper): """ - A decision tree model for classification or regression. + .. note:: Experimental - EXPERIMENTAL: This is an experimental API. - It will probably be modified in future. + A decision tree model for classification or regression. """ def predict(self, x): """ @@ -64,12 +96,10 @@ def toDebugString(self): class DecisionTree(object): - """ - Learning algorithm for a decision tree model for classification or regression. + .. note:: Experimental - EXPERIMENTAL: This is an experimental API. - It will probably be modified in future. + Learning algorithm for a decision tree model for classification or regression. """ @classmethod @@ -186,51 +216,19 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) -class RandomForestModel(JavaModelWrapper): +class RandomForestModel(TreeEnsembleModel): """ - Represents a random forest model. + .. note:: Experimental - EXPERIMENTAL: This is an experimental API. - It will probably be modified in future. + Represents a random forest model. 
""" - def predict(self, x): - """ - Predict values for a single data point or an RDD of points using - the model trained. - """ - if isinstance(x, RDD): - return self.call("predict", x.map(_convert_to_vector)) - - else: - return self.call("predict", _convert_to_vector(x)) - - def numTrees(self): - """ - Get number of trees in forest. - """ - return self.call("numTrees") - - def totalNumNodes(self): - """ - Get total number of nodes, summed over all trees in the forest. - """ - return self.call("totalNumNodes") - - def __repr__(self): - """ Summary of model """ - return self._java_model.toString() - - def toDebugString(self): - """ Full model """ - return self._java_model.toDebugString() class RandomForest(object): """ - Learning algorithm for a random forest model for classification or regression. + .. note:: Experimental - EXPERIMENTAL: This is an experimental API. - It will probably be modified in future. + Learning algorithm for a random forest model for classification or regression. """ supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird") @@ -383,6 +381,137 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt featureSubsetStrategy, impurity, maxDepth, maxBins, seed) +class GradientBoostedTreesModel(TreeEnsembleModel): + """ + .. note:: Experimental + + Represents a gradient-boosted tree model. + """ + + +class GradientBoostedTrees(object): + """ + .. note:: Experimental + + Learning algorithm for a gradient boosted trees model for classification or regression. 
+ """ + + @classmethod + def _train(cls, data, algo, categoricalFeaturesInfo, + loss, numIterations, learningRate, maxDepth): + first = data.first() + assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" + model = callMLlibFunc("trainGradientBoostedTreesModel", data, algo, categoricalFeaturesInfo, + loss, numIterations, learningRate, maxDepth) + return GradientBoostedTreesModel(model) + + @classmethod + def trainClassifier(cls, data, categoricalFeaturesInfo, + loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3): + """ + Method to train a gradient-boosted trees model for classification. + + :param data: Training dataset: RDD of LabeledPoint. Labels should take values {0, 1}. + :param categoricalFeaturesInfo: Map storing arity of categorical + features. E.g., an entry (n -> k) indicates that feature + n is categorical with k categories indexed from 0: + {0, 1, ..., k-1}. + :param loss: Loss function used for minimization during gradient boosting. + Supported: {"logLoss" (default), "leastSquaresError", "leastAbsoluteError"}. + :param numIterations: Number of iterations of boosting. + (default: 100) + :param learningRate: Learning rate for shrinking the contribution of each estimator. + The learning rate should be between in the interval (0, 1] + (default: 0.1) + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 + leaf node; depth 1 means 1 internal node + 2 leaf nodes. + (default: 3) + :return: GradientBoostedTreesModel that can be used for prediction + + Example usage: + + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import GradientBoostedTrees + >>> + >>> data = [ + ... LabeledPoint(0.0, [0.0]), + ... LabeledPoint(0.0, [1.0]), + ... LabeledPoint(1.0, [2.0]), + ... LabeledPoint(1.0, [3.0]) + ... 
] + >>> + >>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {}) + >>> model.numTrees() + 100 + >>> model.totalNumNodes() + 300 + >>> print model, # it already has newline + TreeEnsembleModel classifier with 100 trees + >>> model.predict([2.0]) + 1.0 + >>> model.predict([0.0]) + 0.0 + >>> rdd = sc.parallelize([[2.0], [0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] + """ + return cls._train(data, "classification", categoricalFeaturesInfo, + loss, numIterations, learningRate, maxDepth) + + @classmethod + def trainRegressor(cls, data, categoricalFeaturesInfo, + loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3): + """ + Method to train a gradient-boosted trees model for regression. + + :param data: Training dataset: RDD of LabeledPoint. Labels are + real numbers. + :param categoricalFeaturesInfo: Map storing arity of categorical + features. E.g., an entry (n -> k) indicates that feature + n is categorical with k categories indexed from 0: + {0, 1, ..., k-1}. + :param loss: Loss function used for minimization during gradient boosting. + Supported: {"logLoss", "leastSquaresError" (default), "leastAbsoluteError"}. + :param numIterations: Number of iterations of boosting. + (default: 100) + :param learningRate: Learning rate for shrinking the contribution of each estimator. + The learning rate should be in the interval (0, 1] + (default: 0.1) + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 + leaf node; depth 1 means 1 internal node + 2 leaf nodes. + (default: 3) + :return: GradientBoostedTreesModel that can be used for prediction + + Example usage: + + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import GradientBoostedTrees + >>> from pyspark.mllib.linalg import SparseVector + >>> + >>> sparse_data = [ + ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), + ...
LabeledPoint(0.0, SparseVector(2, {0: 1.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) + ... ] + >>> + >>> model = GradientBoostedTrees.trainRegressor(sc.parallelize(sparse_data), {}) + >>> model.numTrees() + 100 + >>> model.totalNumNodes() + 102 + >>> model.predict(SparseVector(2, {1: 1.0})) + 1.0 + >>> model.predict(SparseVector(2, {0: 1.0})) + 0.0 + >>> rdd = sc.parallelize([[0.0, 1.0], [1.0, 0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] + """ + return cls._train(data, "regression", categoricalFeaturesInfo, + loss, numIterations, learningRate, maxDepth) + + def _test(): import doctest globs = globals().copy() diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py new file mode 100644 index 0000000000000..4408996db0790 --- /dev/null +++ b/python/pyspark/profiler.py @@ -0,0 +1,172 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import cProfile +import pstats +import os +import atexit + +from pyspark.accumulators import AccumulatorParam + + +class ProfilerCollector(object): + """ + This class keeps track of different profilers on a per + stage basis. Also this is used to create new profilers for + the different stages. 
+ """ + + def __init__(self, profiler_cls, dump_path=None): + self.profiler_cls = profiler_cls + self.profile_dump_path = dump_path + self.profilers = [] + + def new_profiler(self, ctx): + """ Create a new profiler using class `profiler_cls` """ + return self.profiler_cls(ctx) + + def add_profiler(self, id, profiler): + """ Add a profiler for RDD `id` """ + if not self.profilers: + if self.profile_dump_path: + atexit.register(self.dump_profiles, self.profile_dump_path) + else: + atexit.register(self.show_profiles) + + self.profilers.append([id, profiler, False]) + + def dump_profiles(self, path): + """ Dump the profile stats into directory `path` """ + for id, profiler, _ in self.profilers: + profiler.dump(id, path) + self.profilers = [] + + def show_profiles(self): + """ Print the profile stats to stdout """ + for i, (id, profiler, showed) in enumerate(self.profilers): + if not showed and profiler: + profiler.show(id) + # mark it as showed + self.profilers[i][2] = True + + +class Profiler(object): + """ + .. note:: DeveloperApi + + PySpark supports custom profilers, this is to allow for different profilers to + be used as well as outputting to different formats than what is provided in the + BasicProfiler. + + A custom profiler has to define or inherit the following methods: + profile - will produce a system profile of some sort. + stats - return the collected stats. + dump - dumps the profiles to a path + add - adds a profile to the existing accumulated profile + + The profiler class is chosen when creating a SparkContext + + >>> from pyspark import SparkConf, SparkContext + >>> from pyspark import BasicProfiler + >>> class MyCustomProfiler(BasicProfiler): + ... def show(self, id): + ... print "My custom profiles for RDD:%s" % id + ... 
+ >>> conf = SparkConf().set("spark.python.profile", "true") + >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler) + >>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) + [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] + >>> sc.show_profiles() + My custom profiles for RDD:1 + My custom profiles for RDD:2 + >>> sc.stop() + """ + + def __init__(self, ctx): + pass + + def profile(self, func): + """ Do profiling on the function `func`""" + raise NotImplementedError + + def stats(self): + """ Return the collected profiling stats (pstats.Stats)""" + raise NotImplementedError + + def show(self, id): + """ Print the profile stats to stdout, id is the RDD id """ + stats = self.stats() + if stats: + print "=" * 60 + print "Profile of RDD<id=%d>" % id + print "=" * 60 + stats.sort_stats("time", "cumulative").print_stats() + + def dump(self, id, path): + """ Dump the profile into path, id is the RDD id """ + if not os.path.exists(path): + os.makedirs(path) + stats = self.stats() + if stats: + p = os.path.join(path, "rdd_%d.pstats" % id) + stats.dump_stats(p) + + +class PStatsParam(AccumulatorParam): + """PStatsParam is used to merge pstats.Stats""" + + @staticmethod + def zero(value): + return None + + @staticmethod + def addInPlace(value1, value2): + if value1 is None: + return value2 + value1.add(value2) + return value1 + + +class BasicProfiler(Profiler): + """ + BasicProfiler is the default profiler, which is implemented based on + cProfile and Accumulator + """ + def __init__(self, ctx): + Profiler.__init__(self, ctx) + # Creates a new accumulator for combining the profiles of different + # partitions of a stage + self._accumulator = ctx.accumulator(None, PStatsParam) + + def profile(self, func): + """ Runs and profiles the function `func`; the collected stats are added to the accumulator. 
""" + pr = cProfile.Profile() + pr.runcall(func) + st = pstats.Stats(pr) + st.stream = None # make it picklable + st.strip_dirs() + + # Adds a new profile to the existing accumulated value + self._accumulator.add(st) + + def stats(self): + return self._accumulator.value + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4977400ac1c05..2f8a0edfe9644 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -29,9 +29,8 @@ import heapq import bisect import random -from math import sqrt, log, isinf, isnan +from math import sqrt, log, isinf, isnan, pow, ceil -from pyspark.accumulators import PStatsParam from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer @@ -141,6 +140,17 @@ def id(self): def __repr__(self): return self._jrdd.toString() + def __getnewargs__(self): + # This method is called when attempting to pickle an RDD, which is always an error: + raise Exception( + "It appears that you are attempting to broadcast an RDD or reference an RDD from an " + "action or transformation. RDD transformations and actions can only be invoked by the " + "driver, not inside of other transformations; for example, " + "rdd1.map(lambda x: rdd2.values.count() * x) is invalid because the values " + "transformation and count action cannot be performed inside of the rdd1.map " + "transformation. For more information, see SPARK-5063." + ) + @property def context(self): """ @@ -716,6 +726,43 @@ def func(iterator): return reduce(f, vals) raise ValueError("Can not reduce() empty RDD") + def treeReduce(self, f, depth=2): + """ + Reduces the elements of this RDD in a multi-level tree pattern. 
+ + :param depth: suggested depth of the tree (default: 2) + + >>> add = lambda x, y: x + y + >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10) + >>> rdd.treeReduce(add) + -5 + >>> rdd.treeReduce(add, 1) + -5 + >>> rdd.treeReduce(add, 2) + -5 + >>> rdd.treeReduce(add, 5) + -5 + >>> rdd.treeReduce(add, 10) + -5 + """ + if depth < 1: + raise ValueError("Depth cannot be smaller than 1 but got %d." % depth) + + zeroValue = None, True # Use the second entry to indicate whether this is a dummy value. + + def op(x, y): + if x[1]: + return y + elif y[1]: + return x + else: + return f(x[0], y[0]), False + + reduced = self.map(lambda x: (x, False)).treeAggregate(zeroValue, op, op, depth) + if reduced[1]: + raise ValueError("Cannot reduce empty RDD.") + return reduced[0] + def fold(self, zeroValue, op): """ Aggregate the elements of each partition, and then the results for all @@ -767,6 +814,58 @@ def func(iterator): return self.mapPartitions(func).fold(zeroValue, combOp) + def treeAggregate(self, zeroValue, seqOp, combOp, depth=2): + """ + Aggregates the elements of this RDD in a multi-level tree + pattern. + + :param depth: suggested depth of the tree (default: 2) + + >>> add = lambda x, y: x + y + >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10) + >>> rdd.treeAggregate(0, add, add) + -5 + >>> rdd.treeAggregate(0, add, add, 1) + -5 + >>> rdd.treeAggregate(0, add, add, 2) + -5 + >>> rdd.treeAggregate(0, add, add, 5) + -5 + >>> rdd.treeAggregate(0, add, add, 10) + -5 + """ + if depth < 1: + raise ValueError("Depth cannot be smaller than 1 but got %d." 
% depth) + + if self.getNumPartitions() == 0: + return zeroValue + + def aggregatePartition(iterator): + acc = zeroValue + for obj in iterator: + acc = seqOp(acc, obj) + yield acc + + partiallyAggregated = self.mapPartitions(aggregatePartition) + numPartitions = partiallyAggregated.getNumPartitions() + scale = max(int(ceil(pow(numPartitions, 1.0 / depth))), 2) + # If creating an extra level doesn't help reduce the wall-clock time, we stop the tree + # aggregation. + while numPartitions > scale + numPartitions / scale: + numPartitions /= scale + curNumPartitions = numPartitions + + def mapPartition(i, iterator): + for obj in iterator: + yield (i % curNumPartitions, obj) + + partiallyAggregated = partiallyAggregated \ + .mapPartitionsWithIndex(mapPartition) \ + .reduceByKey(combOp, curNumPartitions) \ + .values() + + return partiallyAggregated.reduce(combOp) + def max(self, key=None): """ Find the maximum item in this RDD. @@ -1623,8 +1722,8 @@ def groupByKey(self, numPartitions=None): Hash-partitions the resulting RDD with into numPartitions partitions. Note: If you are grouping in order to perform an aggregation (such as a - sum or average) over each key, using reduceByKey will provide much - better performance. + sum or average) over each key, using reduceByKey or aggregateByKey will + provide much better performance. >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect())) @@ -2048,6 +2147,20 @@ def countApproxDistinct(self, relativeSD=0.05): hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF) return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD) + def toLocalIterator(self): + """ + Return an iterator that contains all of the elements in this RDD. + The iterator will consume as much memory as the largest partition in this RDD. 
+ >>> rdd = sc.parallelize(range(10)) + >>> [x for x in rdd.toLocalIterator()] + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + """ + partitions = xrange(self.getNumPartitions()) + for partition in partitions: + rows = self.context.runJob(self, lambda x: x, [partition]) + for row in rows: + yield row + class PipelinedRDD(RDD): @@ -2107,9 +2220,13 @@ def _jrdd(self): return self._jrdd_val if self._bypass_serializer: self._jrdd_deserializer = NoOpSerializer() - enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true" - profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None - command = (self.func, profileStats, self._prev_jrdd_deserializer, + + if self.ctx.profiler_collector: + profiler = self.ctx.profiler_collector.new_profiler(self.ctx) + else: + profiler = None + + command = (self.func, profiler, self._prev_jrdd_deserializer, self._jrdd_deserializer) # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() @@ -2132,9 +2249,9 @@ def _jrdd(self): broadcast_vars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() - if enable_profile: + if profiler: self._id = self._jrdd_val.id() - self.ctx._add_profile(self._id, profileStats) + self.ctx.profiler_collector.add_profiler(self._id, profiler) return self._jrdd_val def id(self): diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 1990323249cf6..3f2d7ac82585f 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -20,15 +20,19 @@ - L{SQLContext} Main entry point for SQL functionality. - - L{SchemaRDD} + - L{DataFrame} A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, SchemaRDDs also support SQL. + addition to normal RDD operations, DataFrames also support SQL. + - L{GroupedDataFrame} + - L{Column} + Column is a DataFrame with a single column. - L{Row} A Row of data returned by a Spark SQL query. 
- L{HiveContext} Main entry point for accessing data stored in Apache Hive.. """ +import sys import itertools import decimal import datetime @@ -36,6 +40,9 @@ import warnings import json import re +import random +import os +from tempfile import NamedTemporaryFile from array import array from operator import itemgetter from itertools import imap @@ -43,6 +50,7 @@ from py4j.protocol import Py4JError from py4j.java_collections import ListConverter, MapConverter +from pyspark.context import SparkContext from pyspark.rdd import RDD from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ CloudPickleSerializer, UTF8Deserializer @@ -54,7 +62,8 @@ "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "SchemaRDD", "Row"] + "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", + "SchemaRDD"] class DataType(object): @@ -922,7 +931,7 @@ def _parse_schema_abstract(s): def _infer_schema_type(obj, dataType): """ - Fill the dataType with types infered from obj + Fill the dataType with types inferred from obj >>> schema = _parse_schema_abstract("a b c d") >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) @@ -1171,7 +1180,7 @@ def Dict(d): class Row(tuple): - """ Row in SchemaRDD """ + """ Row in DataFrame """ __DATATYPE__ = dataType __FIELDS__ = tuple(f.name for f in dataType.fields) __slots__ = () @@ -1198,7 +1207,7 @@ class SQLContext(object): """Main entry point for Spark SQL functionality. - A SQLContext can be used create L{SchemaRDD}, register L{SchemaRDD} as + A SQLContext can be used create L{DataFrame}, register L{DataFrame} as tables, execute SQL over tables, cache tables, and read parquet files. 
""" @@ -1209,8 +1218,8 @@ def __init__(self, sparkContext, sqlContext=None): :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new SQLContext in the JVM, instead we make all calls to this object. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... TypeError:... @@ -1225,12 +1234,12 @@ def __init__(self, sparkContext, sqlContext=None): >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) - >>> srdd = sqlCtx.inferSchema(allTypes) - >>> srdd.registerTempTable("allTypes") + >>> df = sqlCtx.inferSchema(allTypes) + >>> df.registerTempTable("allTypes") >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] - >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, + >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, ... x.row.a, x.list)).collect() [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ @@ -1309,23 +1318,23 @@ def inferSchema(self, rdd, samplingRatio=None): ... [Row(field1=1, field2="row1"), ... Row(field1=2, field2="row2"), ... Row(field1=3, field2="row3")]) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect()[0] + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] Row(field1=1, field2=u'row1') >>> NestedRow = Row("f1", "f2") >>> nestedRdd1 = sc.parallelize([ ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), ... 
NestedRow(array('i', [2, 3]), {"row2": 2.0})]) - >>> srdd = sqlCtx.inferSchema(nestedRdd1) - >>> srdd.collect() + >>> df = sqlCtx.inferSchema(nestedRdd1) + >>> df.collect() [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] >>> nestedRdd2 = sc.parallelize([ ... NestedRow([[1, 2], [2, 3]], [1, 2]), ... NestedRow([[2, 3], [3, 4]], [2, 3])]) - >>> srdd = sqlCtx.inferSchema(nestedRdd2) - >>> srdd.collect() + >>> df = sqlCtx.inferSchema(nestedRdd2) + >>> df.collect() [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] >>> from collections import namedtuple @@ -1334,13 +1343,13 @@ def inferSchema(self, rdd, samplingRatio=None): ... [CustomRow(field1=1, field2="row1"), ... CustomRow(field1=2, field2="row2"), ... CustomRow(field1=3, field2="row3")]) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect()[0] + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] Row(field1=1, field2=u'row1') """ - if isinstance(rdd, SchemaRDD): - raise TypeError("Cannot apply schema to SchemaRDD") + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") first = rdd.first() if not first: @@ -1384,10 +1393,10 @@ def applySchema(self, rdd, schema): >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) >>> schema = StructType([StructField("field1", IntegerType(), False), ... StructField("field2", StringType(), False)]) - >>> srdd = sqlCtx.applySchema(rdd2, schema) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.sql("SELECT * from table1") - >>> srdd2.collect() + >>> df = sqlCtx.applySchema(rdd2, schema) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT * from table1") + >>> df2.collect() [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] >>> from datetime import date, datetime @@ -1410,15 +1419,15 @@ def applySchema(self, rdd, schema): ... StructType([StructField("b", ShortType(), False)]), False), ... StructField("list", ArrayType(ByteType(), False), False), ... 
StructField("null", DoubleType(), True)]) - >>> srdd = sqlCtx.applySchema(rdd, schema) - >>> results = srdd.map( + >>> df = sqlCtx.applySchema(rdd, schema) + >>> results = df.map( ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, ... x.time, x.map["a"], x.struct.b, x.list, x.null)) >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) - >>> srdd.registerTempTable("table2") + >>> df.registerTempTable("table2") >>> sqlCtx.sql( ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + @@ -1431,13 +1440,13 @@ def applySchema(self, rdd, schema): >>> abstract = "byte short float time map{} struct(b) list[]" >>> schema = _parse_schema_abstract(abstract) >>> typedSchema = _infer_schema_type(rdd.first(), schema) - >>> srdd = sqlCtx.applySchema(rdd, typedSchema) - >>> srdd.collect() + >>> df = sqlCtx.applySchema(rdd, typedSchema) + >>> df.collect() [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] """ - if isinstance(rdd, SchemaRDD): - raise TypeError("Cannot apply schema to SchemaRDD") + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") if not isinstance(schema, StructType): raise TypeError("schema should be StructType") @@ -1457,8 +1466,8 @@ def applySchema(self, rdd, schema): rdd = rdd.map(converter) jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return SchemaRDD(srdd, self) + df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + return DataFrame(df, self) def registerRDDAsTable(self, rdd, tableName): """Registers the given RDD as a temporary table in the catalog. 
@@ -1466,34 +1475,34 @@ def registerRDDAsTable(self, rdd, tableName): Temporary tables exist only during the lifetime of this instance of SQLContext. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") """ - if (rdd.__class__ is SchemaRDD): - srdd = rdd._jschema_rdd.baseSchemaRDD() - self._ssql_ctx.registerRDDAsTable(srdd, tableName) + if (rdd.__class__ is DataFrame): + df = rdd._jdf + self._ssql_ctx.registerRDDAsTable(df, tableName) else: - raise ValueError("Can only register SchemaRDD as table") + raise ValueError("Can only register DataFrame as table") def parquetFile(self, path): - """Loads a Parquet file, returning the result as a L{SchemaRDD}. + """Loads a Parquet file, returning the result as a L{DataFrame}. >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.saveAsParquetFile(parquetFile) - >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df.collect()) == sorted(df2.collect()) True """ - jschema_rdd = self._ssql_ctx.parquetFile(path) - return SchemaRDD(jschema_rdd, self) + jdf = self._ssql_ctx.parquetFile(path) + return DataFrame(jdf, self) def jsonFile(self, path, schema=None, samplingRatio=1.0): """ Loads a text file storing one JSON object per line as a - L{SchemaRDD}. + L{DataFrame}. If the schema is provided, applies the given schema to this JSON dataset. @@ -1508,23 +1517,23 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): >>> for json in jsonStrings: ... 
print>>ofn, json >>> ofn.close() - >>> srdd1 = sqlCtx.jsonFile(jsonFile) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( + >>> df1 = sqlCtx.jsonFile(jsonFile) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table1") - >>> for r in srdd2.collect(): + >>> for r in df2.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) - >>> sqlCtx.registerRDDAsTable(srdd3, "table2") - >>> srdd4 = sqlCtx.sql( + >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema()) + >>> sqlCtx.registerRDDAsTable(df3, "table2") + >>> df4 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table2") - >>> for r in srdd4.collect(): + >>> for r in df4.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) @@ -1536,23 +1545,23 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): ... StructType([ ... StructField("field5", ... ArrayType(IntegerType(), False), True)]), False)]) - >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema) - >>> sqlCtx.registerRDDAsTable(srdd5, "table3") - >>> srdd6 = sqlCtx.sql( + >>> df5 = sqlCtx.jsonFile(jsonFile, schema) + >>> sqlCtx.registerRDDAsTable(df5, "table3") + >>> df6 = sqlCtx.sql( ... "SELECT field2 AS f1, field3.field5 as f2, " ... 
"field3.field5[0] as f3 from table3") - >>> srdd6.collect() + >>> df6.collect() [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: - srdd = self._ssql_ctx.jsonFile(path, samplingRatio) + df = self._ssql_ctx.jsonFile(path, samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - srdd = self._ssql_ctx.jsonFile(path, scala_datatype) - return SchemaRDD(srdd, self) + df = self._ssql_ctx.jsonFile(path, scala_datatype) + return DataFrame(df, self) def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. + """Loads an RDD storing one JSON object per string as a L{DataFrame}. If the schema is provided, applies the given schema to this JSON dataset. @@ -1560,23 +1569,23 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): Otherwise, it samples the dataset with ratio `samplingRatio` to determine the schema. - >>> srdd1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( + >>> df1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table1") - >>> for r in srdd2.collect(): + >>> for r in df2.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) - >>> sqlCtx.registerRDDAsTable(srdd3, "table2") - >>> srdd4 = sqlCtx.sql( + >>> df3 = sqlCtx.jsonRDD(json, df1.schema()) + >>> sqlCtx.registerRDDAsTable(df3, "table2") + >>> df4 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table2") - >>> for r in srdd4.collect(): + >>> for r in df4.collect(): ... 
print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) @@ -1588,12 +1597,12 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): ... StructType([ ... StructField("field5", ... ArrayType(IntegerType(), False), True)]), False)]) - >>> srdd5 = sqlCtx.jsonRDD(json, schema) - >>> sqlCtx.registerRDDAsTable(srdd5, "table3") - >>> srdd6 = sqlCtx.sql( + >>> df5 = sqlCtx.jsonRDD(json, schema) + >>> sqlCtx.registerRDDAsTable(df5, "table3") + >>> df6 = sqlCtx.sql( ... "SELECT field2 AS f1, field3.field5 as f2, " ... "field3.field5[0] as f3 from table3") - >>> srdd6.collect() + >>> df6.collect() [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] >>> sqlCtx.jsonRDD(sc.parallelize(['{}', @@ -1615,33 +1624,33 @@ def func(iterator): keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) if schema is None: - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return SchemaRDD(srdd, self) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) + return DataFrame(df, self) def sql(self, sqlQuery): - """Return a L{SchemaRDD} representing the result of the given query. + """Return a L{DataFrame} representing the result of the given query. 
- >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") - >>> srdd2.collect() + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> df2.collect() [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ - return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self) + return DataFrame(self._ssql_ctx.sql(sqlQuery), self) def table(self, tableName): - """Returns the specified table as a L{SchemaRDD}. + """Returns the specified table as a L{DataFrame}. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.table("table1") - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.table("table1") + >>> sorted(df.collect()) == sorted(df2.collect()) True """ - return SchemaRDD(self._ssql_ctx.table(tableName), self) + return DataFrame(self._ssql_ctx.table(tableName), self) def cacheTable(self, tableName): """Caches the specified table in-memory.""" @@ -1707,7 +1716,7 @@ def _create_row(fields, values): class Row(tuple): """ - A row in L{SchemaRDD}. The fields in it can be accessed like attributes. + A row in L{DataFrame}. The fields in it can be accessed like attributes. Row can be used to create a row object by using named arguments, the fields will be sorted by names. 
@@ -1785,125 +1794,119 @@ def __repr__(self): return "" % ", ".join(self) -def inherit_doc(cls): - for name, func in vars(cls).items(): - # only inherit docstring for public functions - if name.startswith("_"): - continue - if not func.__doc__: - for parent in cls.__bases__: - parent_func = getattr(parent, name, None) - if parent_func and getattr(parent_func, "__doc__", None): - func.__doc__ = parent_func.__doc__ - break - return cls +class DataFrame(object): + """A collection of rows that have the same columns. -@inherit_doc -class SchemaRDD(RDD): + A :class:`DataFrame` is equivalent to a relational table in Spark SQL, + and can be created using various functions in :class:`SQLContext`:: - """An RDD of L{Row} objects that has an associated schema. + people = sqlContext.parquetFile("...") - The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can - utilize the relational query api exposed by Spark SQL. + Once created, it can be manipulated using the various domain-specific-language + (DSL) functions defined in: [[DataFrame]], [[Column]]. - For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the - L{SchemaRDD} is not operated on directly, as it's underlying - implementation is an RDD composed of Java objects. Instead it is - converted to a PythonRDD in the JVM, on which Python operations can - be done. + To select a column from the data frame, use the apply method:: - This class receives raw tuples from Java but assigns a class to it in - all its data-collection methods (mapPartitionsWithIndex, collect, take, - etc) so that PySpark sees them as Row objects with named fields. + ageCol = people.age + + Note that the :class:`Column` type can also be manipulated + through its various functions:: + + # The following creates a new column that increases everybody's age by 10. 
+ people.age + 10 + + + A more concrete example:: + + # To create DataFrame using SQLContext + people = sqlContext.parquetFile("...") + department = sqlContext.parquetFile("...") + + people.filter(people.age > 30).join(department, people.deptId == department.id)) \ + .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) """ - def __init__(self, jschema_rdd, sql_ctx): + def __init__(self, jdf, sql_ctx): + self._jdf = jdf self.sql_ctx = sql_ctx - self._sc = sql_ctx._sc - clsName = jschema_rdd.getClass().getName() - assert clsName.endswith("SchemaRDD"), "jschema_rdd must be SchemaRDD" - self._jschema_rdd = jschema_rdd - self._id = None + self._sc = sql_ctx and sql_ctx._sc self.is_cached = False - self.is_checkpointed = False - self.ctx = self.sql_ctx._sc - # the _jrdd is created by javaToPython(), serialized by pickle - self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer()) @property - def _jrdd(self): - """Lazy evaluation of PythonRDD object. + def rdd(self): + """Return the content of the :class:`DataFrame` as an :class:`RDD` + of :class:`Row`s. """ + if not hasattr(self, '_lazy_rdd'): + jrdd = self._jdf.javaToPython() + rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) + schema = self.schema() - Only done when a user calls methods defined by the - L{pyspark.rdd.RDD} super class (map, filter, etc.). - """ - if not hasattr(self, '_lazy_jrdd'): - self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython() - return self._lazy_jrdd + def applySchema(it): + cls = _create_cls(schema) + return itertools.imap(cls, it) - def id(self): - if self._id is None: - self._id = self._jrdd.id() - return self._id + self._lazy_rdd = rdd.mapPartitions(applySchema) + + return self._lazy_rdd def limit(self, num): """Limit the result count to the number specified. 
- >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.limit(2).collect() + >>> df = sqlCtx.inferSchema(rdd) + >>> df.limit(2).collect() [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] - >>> srdd.limit(0).collect() + >>> df.limit(0).collect() [] """ - rdd = self._jschema_rdd.baseSchemaRDD().limit(num) - return SchemaRDD(rdd, self.sql_ctx) + jdf = self._jdf.limit(num) + return DataFrame(jdf, self.sql_ctx) def toJSON(self, use_unicode=False): - """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row. + """Convert a DataFrame into a MappedRDD of JSON documents; one document per row. - >>> srdd1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( "SELECT * from table1") - >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' + >>> df1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( "SELECT * from table1") + >>> df2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' True - >>> srdd3 = sqlCtx.sql( "SELECT field3.field4 from table1") - >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] + >>> df3 = sqlCtx.sql( "SELECT field3.field4 from table1") + >>> df3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] True """ - rdd = self._jschema_rdd.baseSchemaRDD().toJSON() + rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) def saveAsParquetFile(self, path): """Save the contents as a Parquet file, preserving the schema. Files that are written out using this method can be read back in as - a SchemaRDD using the L{SQLContext.parquetFile} method. + a DataFrame using the L{SQLContext.parquetFile} method. 
>>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.saveAsParquetFile(parquetFile) - >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(srdd2.collect()) == sorted(srdd.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df2.collect()) == sorted(df.collect()) True """ - self._jschema_rdd.saveAsParquetFile(path) + self._jdf.saveAsParquetFile(path) def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. The lifetime of this temporary table is tied to the L{SQLContext} - that was used to create this SchemaRDD. + that was used to create this DataFrame. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.registerTempTable("test") - >>> srdd2 = sqlCtx.sql("select * from test") - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.registerTempTable("test") + >>> df2 = sqlCtx.sql("select * from test") + >>> sorted(df.collect()) == sorted(df2.collect()) True """ - self._jschema_rdd.registerTempTable(name) + self._jdf.registerTempTable(name) def registerAsTable(self, name): """DEPRECATED: use registerTempTable() instead""" @@ -1911,62 +1914,61 @@ def registerAsTable(self, name): self.registerTempTable(name) def insertInto(self, tableName, overwrite=False): - """Inserts the contents of this SchemaRDD into the specified table. + """Inserts the contents of this DataFrame into the specified table. Optionally overwriting any existing data. 
""" - self._jschema_rdd.insertInto(tableName, overwrite) + self._jdf.insertInto(tableName, overwrite) def saveAsTable(self, tableName): - """Creates a new table with the contents of this SchemaRDD.""" - self._jschema_rdd.saveAsTable(tableName) + """Creates a new table with the contents of this DataFrame.""" + self._jdf.saveAsTable(tableName) def schema(self): - """Returns the schema of this SchemaRDD (represented by + """Returns the schema of this DataFrame (represented by a L{StructType}).""" - return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json()) - - def schemaString(self): - """Returns the output schema in the tree format.""" - return self._jschema_rdd.schemaString() + return _parse_datatype_json_string(self._jdf.schema().json()) def printSchema(self): """Prints out the schema in the tree format.""" - print self.schemaString() + print (self._jdf.schema().treeString()) def count(self): """Return the number of elements in this RDD. Unlike the base RDD implementation of count, this implementation - leverages the query optimizer to compute the count on the SchemaRDD, + leverages the query optimizer to compute the count on the DataFrame, which supports features such as filter pushdown. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.count() + >>> df = sqlCtx.inferSchema(rdd) + >>> df.count() 3L - >>> srdd.count() == srdd.map(lambda x: x).count() + >>> df.count() == df.map(lambda x: x).count() True """ - return self._jschema_rdd.count() + return self._jdf.count() def collect(self): - """Return a list that contains all of the rows in this RDD. + """Return a list that contains all of the rows. Each object in the list is a Row, the fields can be accessed as attributes. - Unlike the base RDD implementation of collect, this implementation - leverages the query optimizer to perform a collect on the SchemaRDD, - which supports features such as filter pushdown. 
- - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect() + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect() [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')] """ - with SCCallSiteSync(self.context) as css: - bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator() + with SCCallSiteSync(self._sc) as css: + bytesInJava = self._jdf.javaToPython().collect().iterator() cls = _create_cls(self.schema()) - return map(cls, self._collect_iterator_through_file(bytesInJava)) + tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) + tempFile.close() + self._sc._writeToFile(bytesInJava, tempFile.name) + # Read the data into Python and deserialize it: + with open(tempFile.name, 'rb') as tempFile: + rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) + os.unlink(tempFile.name) + return [cls(r) for r in rs] def take(self, num): """Take the first num rows of the RDD. @@ -1974,130 +1976,561 @@ def take(self, num): Each object in the list is a Row, the fields can be accessed as attributes. - Unlike the base RDD implementation of take, this implementation - leverages the query optimizer to perform a collect on a SchemaRDD, - which supports features such as filter pushdown. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.take(2) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.take(2) [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] """ return self.limit(num).collect() - # Convert each object in the RDD to a Row with the right class - # for this SchemaRDD, so that fields can be accessed as attributes. - def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + def map(self, f): + """ Return a new RDD by applying a function to each Row, it's a + shorthand for df.rdd.map() """ - Return a new RDD by applying a function to each partition of this RDD, - while tracking the index of the original partition. 
+ return self.rdd.map(f) - >>> rdd = sc.parallelize([1, 2, 3, 4], 4) - >>> def f(splitIndex, iterator): yield splitIndex - >>> rdd.mapPartitionsWithIndex(f).sum() - 6 + def mapPartitions(self, f, preservesPartitioning=False): """ - rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer) - - schema = self.schema() + Return a new RDD by applying a function to each partition. - def applySchema(_, it): - cls = _create_cls(schema) - return itertools.imap(cls, it) - - objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning) - return objrdd.mapPartitionsWithIndex(f, preservesPartitioning) + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(iterator): yield 1 + >>> rdd.mapPartitions(f).sum() + 4 + """ + return self.rdd.mapPartitions(f, preservesPartitioning) - # We override the default cache/persist/checkpoint behavior - # as we want to cache the underlying SchemaRDD object in the JVM, - # not the PythonRDD checkpointed by the super class def cache(self): + """ Persist with the default storage level (C{MEMORY_ONLY_SER}). + """ self.is_cached = True - self._jschema_rdd.cache() + self._jdf.cache() return self def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + """ Set the storage level to persist its values across operations + after the first time it is computed. This can only be used to assign + a new storage level if the RDD does not have a storage level set yet. + If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + """ self.is_cached = True - javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) - self._jschema_rdd.persist(javaStorageLevel) + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdf.persist(javaStorageLevel) return self def unpersist(self, blocking=True): + """ Mark it as non-persistent, and remove all blocks for it from + memory and disk. 
+ """ self.is_cached = False - self._jschema_rdd.unpersist(blocking) + self._jdf.unpersist(blocking) return self - def checkpoint(self): - self.is_checkpointed = True - self._jschema_rdd.checkpoint() + # def coalesce(self, numPartitions, shuffle=False): + # rdd = self._jdf.coalesce(numPartitions, shuffle, None) + # return DataFrame(rdd, self.sql_ctx) + + def repartition(self, numPartitions): + """ Return a new :class:`DataFrame` that has exactly `numPartitions` + partitions. + """ + rdd = self._jdf.repartition(numPartitions, None) + return DataFrame(rdd, self.sql_ctx) + + def sample(self, withReplacement, fraction, seed=None): + """ + Return a sampled subset of this DataFrame. + + >>> df = sqlCtx.inferSchema(rdd) + >>> df.sample(False, 0.5, 97).count() + 2L + """ + assert fraction >= 0.0, "Negative fraction value: %s" % fraction + seed = seed if seed is not None else random.randint(0, sys.maxint) + rdd = self._jdf.sample(withReplacement, fraction, long(seed)) + return DataFrame(rdd, self.sql_ctx) + + # def takeSample(self, withReplacement, num, seed=None): + # """Return a fixed-size sampled subset of this DataFrame. + # + # >>> df = sqlCtx.inferSchema(rdd) + # >>> df.takeSample(False, 2, 97) + # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] + # """ + # seed = seed if seed is not None else random.randint(0, sys.maxint) + # with SCCallSiteSync(self.context) as css: + # bytesInJava = self._jdf \ + # .takeSampleToPython(withReplacement, num, long(seed)) \ + # .iterator() + # cls = _create_cls(self.schema()) + # return map(cls, self._collect_iterator_through_file(bytesInJava)) + + @property + def dtypes(self): + """Return all column names and their data types as a list. + """ + return [(f.name, str(f.dataType)) for f in self.schema().fields] - def isCheckpointed(self): - return self._jschema_rdd.isCheckpointed() + @property + def columns(self): + """ Return all column names as a list. 
+ """ + return [f.name for f in self.schema().fields] - def getCheckpointFile(self): - checkpointFile = self._jschema_rdd.getCheckpointFile() - if checkpointFile.isDefined(): - return checkpointFile.get() + def show(self): + raise NotImplemented - def coalesce(self, numPartitions, shuffle=False): - rdd = self._jschema_rdd.coalesce(numPartitions, shuffle, None) - return SchemaRDD(rdd, self.sql_ctx) + def join(self, other, joinExprs=None, joinType=None): + """ + Join with another DataFrame, using the given join expression. + The following performs a full outer join between `df1` and `df2`:: - def distinct(self, numPartitions=None): - if numPartitions is None: - rdd = self._jschema_rdd.distinct() + df1.join(df2, df1.key == df2.key, "outer") + + :param other: Right side of the join + :param joinExprs: Join expression + :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, + `semijoin`. + """ + if joinType is None: + if joinExprs is None: + jdf = self._jdf.join(other._jdf) + else: + jdf = self._jdf.join(other._jdf, joinExprs) else: - rdd = self._jschema_rdd.distinct(numPartitions, None) - return SchemaRDD(rdd, self.sql_ctx) + jdf = self._jdf.join(other._jdf, joinExprs, joinType) + return DataFrame(jdf, self.sql_ctx) + + def sort(self, *cols): + """ Return a new [[DataFrame]] sorted by the specified column, + in ascending column. - def intersection(self, other): - if (other.__class__ is SchemaRDD): - rdd = self._jschema_rdd.intersection(other._jschema_rdd) - return SchemaRDD(rdd, self.sql_ctx) + :param cols: The columns or expressions used for sorting + """ + if not cols: + raise ValueError("should sort by at least one column") + for i, c in enumerate(cols): + if isinstance(c, basestring): + cols[i] = Column(c) + jcols = [c._jc for c in cols] + jdf = self._jdf.join(*jcols) + return DataFrame(jdf, self.sql_ctx) + + sortBy = sort + + def head(self, n=None): + """ Return the first `n` rows or the first row if n is None. 
""" + if n is None: + rs = self.head(1) + return rs[0] if rs else None + return self.take(n) + + def tail(self): + raise NotImplemented + + def __getitem__(self, item): + if isinstance(item, basestring): + return Column(self._jdf.apply(item)) + + # TODO projection + raise IndexError + + def __getattr__(self, name): + """ Return the column by given name """ + if name.startswith("__"): + raise AttributeError(name) + return Column(self._jdf.apply(name)) + + def alias(self, name): + """ Alias the current DataFrame """ + return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx) + + def select(self, *cols): + """ Selecting a set of expressions.:: + + df.select() + df.select('colA', 'colB') + df.select(df.colA, df.colB + 1) + + """ + if not cols: + cols = ["*"] + if isinstance(cols[0], basestring): + cols = [_create_column_from_name(n) for n in cols] else: - raise ValueError("Can only intersect with another SchemaRDD") + cols = [c._jc for c in cols] + jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) + jdf = self._jdf.select(self._jdf.toColumnArray(jcols)) + return DataFrame(jdf, self.sql_ctx) - def repartition(self, numPartitions): - rdd = self._jschema_rdd.repartition(numPartitions, None) - return SchemaRDD(rdd, self.sql_ctx) + def filter(self, condition): + """ Filtering rows using the given condition:: - def subtract(self, other, numPartitions=None): - if (other.__class__ is SchemaRDD): - if numPartitions is None: - rdd = self._jschema_rdd.subtract(other._jschema_rdd) - else: - rdd = self._jschema_rdd.subtract(other._jschema_rdd, - numPartitions) - return SchemaRDD(rdd, self.sql_ctx) + df.filter(df.age > 15) + df.where(df.age > 15) + + """ + return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx) + + where = filter + + def groupBy(self, *cols): + """ Group the [[DataFrame]] using the specified columns, + so we can run aggregation on them. 
See :class:`GroupedDataFrame` + for all the available aggregate functions:: + + df.groupBy(df.department).avg() + df.groupBy("department", "gender").agg({ + "salary": "avg", + "age": "max", + }) + """ + if cols and isinstance(cols[0], basestring): + cols = [_create_column_from_name(n) for n in cols] else: - raise ValueError("Can only subtract another SchemaRDD") + cols = [c._jc for c in cols] + jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) + jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols)) + return GroupedDataFrame(jdf, self.sql_ctx) - def sample(self, withReplacement, fraction, seed=None): + def agg(self, *exprs): + """ Aggregate on the entire [[DataFrame]] without groups + (shorthand for df.groupBy.agg()):: + + df.agg({"age": "max", "salary": "avg"}) """ - Return a sampled subset of this SchemaRDD. + return self.groupBy().agg(*exprs) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.sample(False, 0.5, 97).count() - 2L + def unionAll(self, other): + """ Return a new DataFrame containing union of rows in this + frame and another frame. + + This is equivalent to `UNION ALL` in SQL. """ - assert fraction >= 0.0, "Negative fraction value: %s" % fraction - seed = seed if seed is not None else random.randint(0, sys.maxint) - rdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed)) - return SchemaRDD(rdd, self.sql_ctx) + return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) - def takeSample(self, withReplacement, num, seed=None): - """Return a fixed-size sampled subset of this SchemaRDD. + def intersect(self, other): + """ Return a new [[DataFrame]] containing rows only in + both this frame and another frame. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.takeSample(False, 2, 97) - [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] + This is equivalent to `INTERSECT` in SQL. 
""" - seed = seed if seed is not None else random.randint(0, sys.maxint) - with SCCallSiteSync(self.context) as css: - bytesInJava = self._jschema_rdd.baseSchemaRDD() \ - .takeSampleToPython(withReplacement, num, long(seed)) \ - .iterator() - cls = _create_cls(self.schema()) - return map(cls, self._collect_iterator_through_file(bytesInJava)) + return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + + def subtract(self, other): + """ Return a new [[DataFrame]] containing rows in this frame + but not in another frame. + + This is equivalent to `EXCEPT` in SQL. + """ + return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + + def sample(self, withReplacement, fraction, seed=None): + """ Return a new DataFrame by sampling a fraction of rows. """ + if seed is None: + jdf = self._jdf.sample(withReplacement, fraction) + else: + jdf = self._jdf.sample(withReplacement, fraction, seed) + return DataFrame(jdf, self.sql_ctx) + + def addColumn(self, colName, col): + """ Return a new [[DataFrame]] by adding a column. """ + return self.select('*', col.alias(colName)) + + def removeColumn(self, colName): + raise NotImplemented + + +# Having SchemaRDD for backward compatibility (for docs) +class SchemaRDD(DataFrame): + """ + SchemaRDD is deprecated, please use DataFrame + """ + + +def dfapi(f): + def _api(self): + name = f.__name__ + jdf = getattr(self._jdf, name)() + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +class GroupedDataFrame(object): + + """ + A set of methods for aggregations on a :class:`DataFrame`, + created by DataFrame.groupBy(). + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + + def agg(self, *exprs): + """ Compute aggregates by specifying a map from column name + to aggregate methods. + + The available aggregate methods are `avg`, `max`, `min`, + `sum`, `count`. 
+ + :param exprs: list or aggregate columns or a map from column + name to agregate methods. + """ + if len(exprs) == 1 and isinstance(exprs[0], dict): + jmap = MapConverter().convert(exprs[0], + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.agg(jmap) + else: + # Columns + assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns" + jdf = self._jdf.agg(*exprs) + return DataFrame(jdf, self.sql_ctx) + + @dfapi + def count(self): + """ Count the number of rows for each group. """ + + @dfapi + def mean(self): + """Compute the average value for each numeric columns + for each group. This is an alias for `avg`.""" + + @dfapi + def avg(self): + """Compute the average value for each numeric columns + for each group.""" + + @dfapi + def max(self): + """Compute the max value for each numeric columns for + each group. """ + + @dfapi + def min(self): + """Compute the min value for each numeric column for + each group.""" + + @dfapi + def sum(self): + """Compute the sum for each numeric columns for each + group.""" + + +SCALA_METHOD_MAPPINGS = { + '=': '$eq', + '>': '$greater', + '<': '$less', + '+': '$plus', + '-': '$minus', + '*': '$times', + '/': '$div', + '!': '$bang', + '@': '$at', + '#': '$hash', + '%': '$percent', + '^': '$up', + '&': '$amp', + '~': '$tilde', + '?': '$qmark', + '|': '$bar', + '\\': '$bslash', + ':': '$colon', +} + + +def _create_column_from_literal(literal): + sc = SparkContext._active_spark_context + return sc._jvm.org.apache.spark.sql.Dsl.lit(literal) + + +def _create_column_from_name(name): + sc = SparkContext._active_spark_context + return sc._jvm.Column(name) + + +def _scalaMethod(name): + """ Translate operators into methodName in Scala + + For example: + >>> _scalaMethod('+') + '$plus' + >>> _scalaMethod('>=') + '$greater$eq' + >>> _scalaMethod('cast') + 'cast' + """ + return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name) + + +def _unary_op(name): + """ Create a method for given unary operator """ + def 
_(self): + return Column(getattr(self._jc, _scalaMethod(name))(), self._jdf, self.sql_ctx) + return _ + + +def _bin_op(name, pass_literal_through=False): + """ Create a method for given binary operator + + Keyword arguments: + pass_literal_through -- whether to pass literal value directly through to the JVM. + """ + def _(self, other): + if isinstance(other, Column): + jc = other._jc + else: + if pass_literal_through: + jc = other + else: + jc = _create_column_from_literal(other) + return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx) + return _ + + +def _reverse_op(name): + """ Create a method for binary operator (this object is on right side) + """ + def _(self, other): + return Column(getattr(_create_column_from_literal(other), _scalaMethod(name))(self._jc), + self._jdf, self.sql_ctx) + return _ + + +class Column(DataFrame): + + """ + A column in a DataFrame. + + `Column` instances can be created by: + {{{ + // 1. Select a column out of a DataFrame + df.colName + df["colName"] + + // 2. 
Create from an expression + df["colName"] + 1 + }}} + """ + + def __init__(self, jc, jdf=None, sql_ctx=None): + self._jc = jc + super(Column, self).__init__(jdf, sql_ctx) + + # arithmetic operators + __neg__ = _unary_op("unary_-") + __add__ = _bin_op("+") + __sub__ = _bin_op("-") + __mul__ = _bin_op("*") + __div__ = _bin_op("/") + __mod__ = _bin_op("%") + __radd__ = _bin_op("+") + __rsub__ = _reverse_op("-") + __rmul__ = _bin_op("*") + __rdiv__ = _reverse_op("/") + __rmod__ = _reverse_op("%") + __abs__ = _unary_op("abs") + abs = _unary_op("abs") + sqrt = _unary_op("sqrt") + + # logistic operators + __eq__ = _bin_op("===") + __ne__ = _bin_op("!==") + __lt__ = _bin_op("<") + __le__ = _bin_op("<=") + __ge__ = _bin_op(">=") + __gt__ = _bin_op(">") + # `and`, `or`, `not` cannot be overloaded in Python + And = _bin_op('&&') + Or = _bin_op('||') + Not = _unary_op('unary_!') + + # bitwise operators + __and__ = _bin_op("&") + __or__ = _bin_op("|") + __invert__ = _unary_op("unary_~") + __xor__ = _bin_op("^") + # __lshift__ = _bin_op("<<") + # __rshift__ = _bin_op(">>") + __rand__ = _bin_op("&") + __ror__ = _bin_op("|") + __rxor__ = _bin_op("^") + # __rlshift__ = _reverse_op("<<") + # __rrshift__ = _reverse_op(">>") + + # container operators + __contains__ = _bin_op("contains") + __getitem__ = _bin_op("getItem") + # __getattr__ = _bin_op("getField") + + # string methods + rlike = _bin_op("rlike", pass_literal_through=True) + like = _bin_op("like", pass_literal_through=True) + startswith = _bin_op("startsWith", pass_literal_through=True) + endswith = _bin_op("endsWith", pass_literal_through=True) + upper = _unary_op("upper") + lower = _unary_op("lower") + + def substr(self, startPos, pos): + if type(startPos) != type(pos): + raise TypeError("Can not mix the type") + if isinstance(startPos, (int, long)): + + jc = self._jc.substr(startPos, pos) + elif isinstance(startPos, Column): + jc = self._jc.substr(startPos._jc, pos._jc) + else: + raise TypeError("Unexpected type: %s" % 
type(startPos)) + return Column(jc, self._jdf, self.sql_ctx) + + __getslice__ = substr + + # order + asc = _unary_op("asc") + desc = _unary_op("desc") + + isNull = _unary_op("isNull") + isNotNull = _unary_op("isNotNull") + + # `as` is keyword + def alias(self, alias): + return Column(getattr(self._jsc, "as")(alias), self._jdf, self.sql_ctx) + + def cast(self, dataType): + if self.sql_ctx is None: + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + else: + ssql_ctx = self.sql_ctx._ssql_ctx + jdt = ssql_ctx.parseDataType(dataType.json()) + return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx) + + +def _aggregate_func(name): + """ Create a function for aggregator by name""" + def _(col): + sc = SparkContext._active_spark_context + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + jc = getattr(sc._jvm.org.apache.spark.sql.Dsl, name)(jcol) + return Column(jc) + return staticmethod(_) + + +class Aggregator(object): + """ + A collections of builtin aggregators + """ + max = _aggregate_func("max") + min = _aggregate_func("min") + avg = mean = _aggregate_func("mean") + sum = _aggregate_func("sum") + first = _aggregate_func("first") + last = _aggregate_func("last") + count = _aggregate_func("count") def _test(): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b474fcf5bfb7e..bec1961f26393 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -23,6 +23,7 @@ from fileinput import input from glob import glob import os +import pydoc import re import shutil import subprocess @@ -53,6 +54,7 @@ from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ UserDefinedType, DoubleType from pyspark import shuffle +from pyspark.profiler import BasicProfiler _have_scipy = False _have_numpy = False @@ -714,6 +716,25 @@ def test_sample(self): wr_s21 = rdd.sample(True, 0.4, 21).collect() self.assertNotEqual(set(wr_s11), set(wr_s21)) + def 
test_multiple_python_java_RDD_conversions(self): + # Regression test for SPARK-5361 + data = [ + (u'1', {u'director': u'David Lean'}), + (u'2', {u'director': u'Andrew Dominik'}) + ] + from pyspark.rdd import RDD + data_rdd = self.sc.parallelize(data) + data_java_rdd = data_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + + # conversion between python and java RDD threw exceptions + data_java_rdd = converted_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + class ProfilerTests(PySparkTestCase): @@ -724,16 +745,12 @@ def setUp(self): self.sc = SparkContext('local[4]', class_name, conf=conf) def test_profiler(self): + self.do_computation() - def heavy_foo(x): - for i in range(1 << 20): - x = 1 - rdd = self.sc.parallelize(range(100)) - rdd.foreach(heavy_foo) - profiles = self.sc._profile_stats - self.assertEqual(1, len(profiles)) - id, acc, _ = profiles[0] - stats = acc.value + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + id, profiler, _ = profilers[0] + stats = profiler.stats() self.assertTrue(stats is not None) width, stat_list = stats.get_print_list([]) func_names = [func_name for fname, n, func_name in stat_list] @@ -744,6 +761,31 @@ def heavy_foo(x): self.sc.dump_profiles(d) self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) + def test_custom_profiler(self): + class TestCustomProfiler(BasicProfiler): + def show(self, id): + self.result = "Custom formatting" + + self.sc.profiler_collector.profiler_cls = TestCustomProfiler + + self.do_computation() + + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + _, profiler, _ = profilers[0] + self.assertTrue(isinstance(profiler, TestCustomProfiler)) + + 
self.sc.show_profiles() + self.assertEqual("Custom formatting", profiler.result) + + def do_computation(self): + def heavy_foo(x): + for i in range(1 << 20): + x = 1 + + rdd = self.sc.parallelize(range(100)) + rdd.foreach(heavy_foo) + class ExamplePointUDT(UserDefinedType): """ @@ -806,6 +848,9 @@ def tearDownClass(cls): def setUp(self): self.sqlCtx = SQLContext(self.sc) + self.testData = [Row(key=i, value=str(i)) for i in range(100)] + rdd = self.sc.parallelize(self.testData) + self.df = self.sqlCtx.inferSchema(rdd) def test_udf(self): self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) @@ -821,7 +866,7 @@ def test_udf2(self): def test_udf_with_array_type(self): d = [Row(l=range(3), d={"key": range(5)})] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test") + self.sqlCtx.inferSchema(rdd).registerTempTable("test") self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() @@ -839,68 +884,51 @@ def test_broadcast_in_udf(self): def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - srdd = self.sqlCtx.jsonRDD(rdd) - srdd.count() - srdd.collect() - srdd.schemaString() - srdd.schema() + df = self.sqlCtx.jsonRDD(rdd) + df.count() + df.collect() + df.schema() # cache and checkpoint - self.assertFalse(srdd.is_cached) - srdd.persist() - srdd.unpersist() - srdd.cache() - self.assertTrue(srdd.is_cached) - self.assertFalse(srdd.isCheckpointed()) - self.assertEqual(None, srdd.getCheckpointFile()) - - srdd = srdd.coalesce(2, True) - srdd = srdd.repartition(3) - srdd = srdd.distinct() - srdd.intersection(srdd) - self.assertEqual(2, srdd.count()) - - srdd.registerTempTable("temp") - srdd = self.sqlCtx.sql("select foo from temp") - srdd.count() - srdd.collect() - - def test_distinct(self): - rdd = 
self.sc.parallelize(['{"a": 1}', '{"b": 2}', '{"c": 3}']*10, 10) - srdd = self.sqlCtx.jsonRDD(rdd) - self.assertEquals(srdd.getNumPartitions(), 10) - self.assertEquals(srdd.distinct().count(), 3) - result = srdd.distinct(5) - self.assertEquals(result.getNumPartitions(), 5) - self.assertEquals(result.count(), 3) + self.assertFalse(df.is_cached) + df.persist() + df.unpersist() + df.cache() + self.assertTrue(df.is_cached) + self.assertEqual(2, df.count()) + + df.registerTempTable("temp") + df = self.sqlCtx.sql("select foo from temp") + df.count() + df.collect() def test_apply_schema_to_row(self): - srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) - srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema()) - self.assertEqual(srdd.collect(), srdd2.collect()) + df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema()) + self.assertEqual(df.collect(), df2.collect()) rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema()) - self.assertEqual(10, srdd3.count()) + df3 = self.sqlCtx.applySchema(rdd, df.schema()) + self.assertEqual(10, df3.count()) def test_serialize_nested_array_and_map(self): d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - row = srdd.first() + df = self.sqlCtx.inferSchema(rdd) + row = df.head() self.assertEqual(1, len(row.l)) self.assertEqual(1, row.l[0].a) self.assertEqual("2", row.d["key"].d) - l = srdd.map(lambda x: x.l).first() + l = df.map(lambda x: x.l).first() self.assertEqual(1, len(l)) self.assertEqual('s', l[0].b) - d = srdd.map(lambda x: x.d).first() + d = df.map(lambda x: x.d).first() self.assertEqual(1, len(d)) self.assertEqual(1.0, d["key"].c) - row = srdd.map(lambda x: x.d["key"]).first() + row = df.map(lambda x: x.d["key"]).first() self.assertEqual(1.0, row.c) self.assertEqual("2", row.d) @@ -908,26 +936,26 
@@ def test_infer_schema(self): d = [Row(l=[], d={}), Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - self.assertEqual([], srdd.map(lambda r: r.l).first()) - self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect()) - srdd.registerTempTable("test") + df = self.sqlCtx.inferSchema(rdd) + self.assertEqual([], df.map(lambda r: r.l).first()) + self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) + df.registerTempTable("test") result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") - self.assertEqual(1, result.first()[0]) + self.assertEqual(1, result.head()[0]) - srdd2 = self.sqlCtx.inferSchema(rdd, 1.0) - self.assertEqual(srdd.schema(), srdd2.schema()) - self.assertEqual({}, srdd2.map(lambda r: r.d).first()) - self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect()) - srdd2.registerTempTable("test2") + df2 = self.sqlCtx.inferSchema(rdd, 1.0) + self.assertEqual(df.schema(), df2.schema()) + self.assertEqual({}, df2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) + df2.registerTempTable("test2") result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") - self.assertEqual(1, result.first()[0]) + self.assertEqual(1, result.head()[0]) def test_struct_in_map(self): d = [Row(m={Row(i=1): Row(s="")})] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - k, v = srdd.first().m.items()[0] + df = self.sqlCtx.inferSchema(rdd) + k, v = df.head().m.items()[0] self.assertEqual(1, k.i) self.assertEqual("", v.s) @@ -935,9 +963,9 @@ def test_convert_row_to_dict(self): row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) self.assertEqual(1, row.asDict()['l'][0].a) rdd = self.sc.parallelize([row]) - srdd = self.sqlCtx.inferSchema(rdd) - srdd.registerTempTable("test") - row = self.sqlCtx.sql("select l, d from test").first() + df = self.sqlCtx.inferSchema(rdd) + 
df.registerTempTable("test") + row = self.sqlCtx.sql("select l, d from test").head() self.assertEqual(1, row.asDict()["l"][0].a) self.assertEqual(1.0, row.asDict()['d']['key'].c) @@ -945,12 +973,12 @@ def test_infer_schema_with_udt(self): from pyspark.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) rdd = self.sc.parallelize([row]) - srdd = self.sqlCtx.inferSchema(rdd) - schema = srdd.schema() + df = self.sqlCtx.inferSchema(rdd) + schema = df.schema() field = [f for f in schema.fields if f.name == "point"][0] self.assertEqual(type(field.dataType), ExamplePointUDT) - srdd.registerTempTable("labeled_point") - point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) def test_apply_schema_with_udt(self): @@ -959,21 +987,61 @@ def test_apply_schema_with_udt(self): rdd = self.sc.parallelize([row]) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), False)]) - srdd = self.sqlCtx.applySchema(rdd, schema) - point = srdd.first().point + df = self.sqlCtx.applySchema(rdd, schema) + point = df.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) def test_parquet_with_udt(self): from pyspark.tests import ExamplePoint row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) rdd = self.sc.parallelize([row]) - srdd0 = self.sqlCtx.inferSchema(rdd) + df0 = self.sqlCtx.inferSchema(rdd) output_dir = os.path.join(self.tempdir.name, "labeled_point") - srdd0.saveAsParquetFile(output_dir) - srdd1 = self.sqlCtx.parquetFile(output_dir) - point = srdd1.first().point + df0.saveAsParquetFile(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + def test_column_operators(self): + from pyspark.sql import Column, LongType + ci = self.df.key + cs 
= self.df.value + c = ci == cs + self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + self.assertTrue(all(isinstance(c, Column) for c in rcc)) + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] + self.assertTrue(all(isinstance(c, Column) for c in cb)) + cbit = (ci & ci), (ci | ci), (ci ^ ci), (~ci) + self.assertTrue(all(isinstance(c, Column) for c in cbit)) + css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') + self.assertTrue(all(isinstance(c, Column) for c in css)) + self.assertTrue(isinstance(ci.cast(LongType()), Column)) + + def test_column_select(self): + df = self.df + self.assertEqual(self.testData, df.select("*").collect()) + self.assertEqual(self.testData, df.select(df.key, df.value).collect()) + self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + + def test_aggregator(self): + df = self.df + g = df.groupBy() + self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) + self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) + # TODO(davies): fix aggregators + from pyspark.sql import Aggregator as Agg + # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first())) + + def test_help_command(self): + # Regression test for SPARK-5464 + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(df) + pydoc.render_doc(df.foo) + pydoc.render_doc(df.take(1)) + class InputFormatTests(ReusedPySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 7e5343c973dc5..8a93c320ec5d3 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,8 +23,6 @@ import time import socket import traceback -import cProfile -import pstats from pyspark.accumulators 
import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry @@ -90,19 +88,15 @@ def main(infile, outfile): command = pickleSer._read_with_length(infile) if isinstance(command, Broadcast): command = pickleSer.loads(command.value) - (func, stats, deserializer, serializer) = command + (func, profiler, deserializer, serializer) = command init_time = time.time() def process(): iterator = deserializer.load_stream(infile) serializer.dump_stream(func(split_index, iterator), outfile) - if stats: - p = cProfile.Profile() - p.runcall(process) - st = pstats.Stats(p) - st.stream = None # make it picklable - stats.add(st.strip_dirs()) + if profiler: + profiler.profile(process) else: process() except Exception: diff --git a/python/run-tests b/python/run-tests index 9ee19ed6e6b26..e91f1a875d356 100755 --- a/python/run-tests +++ b/python/run-tests @@ -57,6 +57,7 @@ function run_core_tests() { PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" run_test "pyspark/serializers.py" + run_test "pyspark/profiler.py" run_test "pyspark/shuffle.py" run_test "pyspark/tests.py" } @@ -75,12 +76,19 @@ function run_mllib_tests() { run_test "pyspark/mllib/rand.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" - run_test "pyspark/mllib/stat.py" + run_test "pyspark/mllib/stat/_statistics.py" run_test "pyspark/mllib/tree.py" run_test "pyspark/mllib/util.py" run_test "pyspark/mllib/tests.py" } +function run_ml_tests() { + echo "Run ml tests ..." + run_test "pyspark/ml/feature.py" + run_test "pyspark/ml/classification.py" + run_test "pyspark/ml/tests.py" +} + function run_streaming_tests() { echo "Run streaming tests ..." 
run_test "pyspark/streaming/util.py" @@ -102,6 +110,7 @@ $PYSPARK_PYTHON --version run_core_tests run_sql_tests run_mllib_tests +run_ml_tests run_streaming_tests # Try to test with PyPy diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 91c9c52c3c98a..e594ad868ea1c 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -255,14 +255,14 @@ class ReplSuite extends FunSuite { assertDoesNotContain("Exception", output) } - test("SPARK-2576 importing SQLContext.createSchemaRDD.") { + test("SPARK-2576 importing SQLContext.createDataFrame.") { // We need to use local-cluster to test this case. val output = runInterpreter("local-cluster[1,1,512]", """ |val sqlContext = new org.apache.spark.sql.SQLContext(sc) - |import sqlContext.createSchemaRDD + |import sqlContext.createDataFrame |case class TestCaseClass(value: Int) - |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toSchemaRDD.collect + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDataFrame.collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) diff --git a/sql/README.md b/sql/README.md index d058a6b011d37..61a20916a92aa 100644 --- a/sql/README.md +++ b/sql/README.md @@ -44,7 +44,7 @@ Type in expressions to have them evaluated. Type :help for more information. 
scala> val query = sql("SELECT * FROM (SELECT * FROM src) a") -query: org.apache.spark.sql.SchemaRDD = +query: org.apache.spark.sql.DataFrame = == Query Plan == == Physical Plan == HiveTableScan [key#10,value#11], (MetastoreRelation default, src, None), None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 191d16fb10b5f..4def65b01f583 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -57,6 +57,7 @@ trait ScalaReflection { case (obj, udt: UserDefinedType[_]) => udt.serialize(obj) case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) + case (s: Array[_], arrayType: ArrayType) => s.toSeq case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) } @@ -140,7 +141,9 @@ trait ScalaReflection { // Need to decide if we actually need a special type here. 
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) case t if t <:< typeOf[Array[_]] => - sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) case t if t <:< typeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index eaadbe9fd5099..24a65f8f4d379 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -348,7 +348,7 @@ class SqlParser extends AbstractSparkSQLParser { ) protected lazy val baseExpression: Parser[Expression] = - ( "*" ^^^ Star(None) + ( "*" ^^^ UnresolvedStar(None) | primary ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7f4cc234dc9cd..cefd70acf3931 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -250,6 +250,12 @@ class Analyzer(catalog: Catalog, Project( projectList.flatMap { case s: Star => s.expand(child.output, resolver) + case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) => + val expandedArgs = args.flatMap { + case s: Star => s.expand(child.output, resolver) + case o => o :: Nil + } + Alias(child = f.copy(children = expandedArgs), name)() :: Nil case o => o :: Nil }, child) @@ -273,10 +279,9 @@ class Analyzer(catalog: Catalog, case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q 
transformExpressions { - case u @ UnresolvedAttribute(name) - if resolver(name, VirtualColumn.groupingIdName) && - q.isInstanceOf[GroupingAnalytics] => - // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics + case u @ UnresolvedAttribute(name) if resolver(name, VirtualColumn.groupingIdName) && + q.isInstanceOf[GroupingAnalytics] => + // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics q.asInstanceOf[GroupingAnalytics].gid case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. @@ -299,7 +304,7 @@ class Analyzer(catalog: Catalog, * Returns true if `exprs` contains a [[Star]]. */ protected def containsStar(exprs: Seq[Expression]): Boolean = - exprs.collect { case _: Star => true}.nonEmpty + exprs.exists(_.collect { case _: Star => true }.nonEmpty) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala index 22941edef2d46..4c5fb3f45bf49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala @@ -47,7 +47,7 @@ object NewRelationInstances extends Rule[LogicalPlan] { .toSet plan transform { - case l: MultiInstanceRelation if multiAppearance contains l => l.newInstance + case l: MultiInstanceRelation if multiAppearance.contains(l) => l.newInstance() } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 71a738a0b2ca0..66060289189ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ 
-50,7 +50,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo override def qualifiers = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - override def newInstance = this + override def newInstance() = this override def withNullability(newNullability: Boolean) = this override def withQualifiers(newQualifiers: Seq[String]) = this override def withName(newName: String) = UnresolvedAttribute(name) @@ -77,15 +77,10 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E /** * Represents all of the input attributes to a given relational operator, for example in - * "SELECT * FROM ...". - * - * @param table an optional table that should be the target of the expansion. If omitted all - * tables' columns are produced. + * "SELECT * FROM ...". A [[Star]] gets automatically expanded during analysis. */ -case class Star( - table: Option[String], - mapFunction: Attribute => Expression = identity[Attribute]) - extends Attribute with trees.LeafNode[Expression] { +trait Star extends Attribute with trees.LeafNode[Expression] { + self: Product => override def name = throw new UnresolvedException(this, "name") override def exprId = throw new UnresolvedException(this, "exprId") @@ -94,29 +89,53 @@ case class Star( override def qualifiers = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - override def newInstance = this + override def newInstance() = this override def withNullability(newNullability: Boolean) = this override def withQualifiers(newQualifiers: Seq[String]) = this override def withName(newName: String) = this - def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = { + // Star gets expanded at runtime so we never evaluate a Star. + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. 
type: ${this.nodeName}") + + def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] +} + + +/** + * Represents all of the input attributes to a given relational operator, for example in + * "SELECT * FROM ...". + * + * @param table an optional table that should be the target of the expansion. If omitted all + * tables' columns are produced. + */ +case class UnresolvedStar(table: Option[String]) extends Star { + + override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = { val expandedAttributes: Seq[Attribute] = table match { // If there is no table specified, use all input attributes. case None => input // If there is a table, pick out attributes that are part of this table. case Some(t) => input.filter(_.qualifiers.filter(resolver(_, t)).nonEmpty) } - val mappedAttributes = expandedAttributes.map(mapFunction).zip(input).map { + expandedAttributes.zip(input).map { case (n: NamedExpression, _) => n case (e, originalAttribute) => Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers) } - mappedAttributes } - // Star gets expanded at runtime so we never evaluate a Star. - override def eval(input: Row = null): EvaluatedType = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString = table.map(_ + ".").getOrElse("") + "*" } + + +/** + * Represents all the resolved input attributes to a given relational operator. This is used + * in the data frame DSL. + * + * @param expressions Expressions to expand. 
+ */ +case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star { + override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions + override def toString = expressions.mkString("ResolvedStar(", ", ", ")") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 3035d934ff9f8..f388cd5972bac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -77,6 +77,9 @@ abstract class Attribute extends NamedExpression { * For example the SQL expression "1 + 1 AS a" could be represented as follows: * Alias(Add(Literal(1), Literal(1), "a")() * + * Note that exprId and qualifiers are in a separate parameter list because + * we only pattern match on child and name. + * * @param child the computation being performed * @param name the name to be associated with the result of computing [[child]]. 
* @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 8df150e2f855f..73ec7a6d114f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -114,7 +114,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } override def getString(i: Int): String = { - if (values(i) == null) sys.error("Failed to check null bit for primitive String value.") values(i).asInstanceOf[String] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 310d127506d68..b4c445b3badf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -141,10 +141,11 @@ object PartialAggregation { // We need to pass all grouping expressions though so the grouping can happen a second // time. However some of them might be unnamed so we alias them allowing them to be // referenced in the second aggregation. - val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map { - case n: NamedExpression => (n, n) - case other => (other, Alias(other, "PartialGroup")()) - }.toMap + val namedGroupingExpressions: Map[Expression, NamedExpression] = + groupingExpressions.filter(!_.isInstanceOf[Literal]).map { + case n: NamedExpression => (n, n) + case other => (other, Alias(other, "PartialGroup")()) + }.toMap // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 613f4bb09daf5..5dc0539caec24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -17,9 +17,24 @@ package org.apache.spark.sql.catalyst.plans +object JoinType { + def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match { + case "inner" => Inner + case "outer" | "full" | "fullouter" => FullOuter + case "leftouter" | "left" => LeftOuter + case "rightouter" | "right" => RightOuter + case "leftsemi" => LeftSemi + } +} + sealed abstract class JoinType + case object Inner extends JoinType + case object LeftOuter extends JoinType + case object RightOuter extends JoinType + case object FullOuter extends JoinType + case object LeftSemi extends JoinType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala index 19769986ef58c..d90af45b375e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala @@ -19,10 +19,14 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.{StructType, StructField} object LocalRelation { - def apply(output: Attribute*) = - new LocalRelation(output) + def apply(output: Attribute*): LocalRelation = new LocalRelation(output) + + def apply(output1: StructField, output: StructField*): LocalRelation = new LocalRelation( + StructType(output1 +: output).toAttributes + ) } case class LocalRelation(output: 
Seq[Attribute], data: Seq[Product] = Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index 9f30f40a173e0..6ab99aa38877f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -930,13 +930,13 @@ case class MapType( * * This interface allows a user to make their own classes more interoperable with SparkSQL; * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create - * a SchemaRDD which has class X in the schema. + * a `DataFrame` which has class X in the schema. * * For SparkSQL to recognize UDTs, the UDT must be annotated with * [[SQLUserDefinedType]]. * - * The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD. - * The conversion via `deserialize` occurs when reading from a `SchemaRDD`. + * The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD. + * The conversion via `deserialize` occurs when reading from a `DataFrame`. 
*/ @DeveloperApi abstract class UserDefinedType[UserType] extends DataType with Serializable { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 5138942a55daa..4a66716e0a782 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -60,6 +60,7 @@ case class OptionalData( case class ComplexData( arrayField: Seq[Int], + arrayField1: Array[Int], arrayFieldContainsNull: Seq[java.lang.Integer], mapField: Map[Int, Long], mapFieldValueContainsNull: Map[Int, java.lang.Long], @@ -131,6 +132,10 @@ class ScalaReflectionSuite extends FunSuite { "arrayField", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField( + "arrayField1", + ArrayType(IntegerType, containsNull = false), + nullable = true), StructField( "arrayFieldContainsNull", ArrayType(IntegerType, containsNull = true), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 3aea337460d42..60060bf02913b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -51,7 +51,9 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { test("union project *") { val plan = (1 to 100) .map(_ => testRelation) - .fold[LogicalPlan](testRelation)((a,b) => a.select(Star(None)).select('a).unionAll(b.select(Star(None)))) + .fold[LogicalPlan](testRelation) { (a, b) => + a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None))) + } assert(caseInsensitiveAnalyze(plan).resolved) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index e715d9434a2ab..f1949aa5dd74b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.util.concurrent.locks.ReentrantReadWriteLock +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel @@ -32,9 +33,10 @@ private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryR * results when subsequent queries are executed. Data is cached using byte buffers stored in an * InMemoryRelation. This relation is automatically substituted query plans that return the * `sameResult` as the originally cached query. + * + * Internal to Spark SQL. */ -private[sql] trait CacheManager { - self: SQLContext => +private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { @transient private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData] @@ -43,13 +45,13 @@ private[sql] trait CacheManager { private val cacheLock = new ReentrantReadWriteLock /** Returns true if the table is currently cached in-memory. */ - def isCached(tableName: String): Boolean = lookupCachedData(table(tableName)).nonEmpty + def isCached(tableName: String): Boolean = lookupCachedData(sqlContext.table(tableName)).nonEmpty /** Caches the specified table in-memory. */ - def cacheTable(tableName: String): Unit = cacheQuery(table(tableName), Some(tableName)) + def cacheTable(tableName: String): Unit = cacheQuery(sqlContext.table(tableName), Some(tableName)) /** Removes the specified table from the in-memory cache. 
*/ - def uncacheTable(tableName: String): Unit = uncacheQuery(table(tableName)) + def uncacheTable(tableName: String): Unit = uncacheQuery(sqlContext.table(tableName)) /** Acquires a read lock on the cache for the duration of `f`. */ private def readLock[A](f: => A): A = { @@ -80,7 +82,7 @@ private[sql] trait CacheManager { * the in-memory columnar representation of the underlying table is expensive. */ private[sql] def cacheQuery( - query: SchemaRDD, + query: DataFrame, tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.analyzed @@ -91,16 +93,16 @@ private[sql] trait CacheManager { CachedData( planToCache, InMemoryRelation( - conf.useCompression, - conf.columnBatchSize, + sqlContext.conf.useCompression, + sqlContext.conf.columnBatchSize, storageLevel, query.queryExecution.executedPlan, tableName)) } } - /** Removes the data for the given SchemaRDD from the cache */ - private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = true): Unit = writeLock { + /** Removes the data for the given [[DataFrame]] from the cache */ + private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) require(dataIndex >= 0, s"Table $query is not cached.") @@ -108,9 +110,9 @@ private[sql] trait CacheManager { cachedData.remove(dataIndex) } - /** Tries to remove the data for the given SchemaRDD from the cache if it's cached */ + /** Tries to remove the data for the given [[DataFrame]] from the cache if it's cached */ private[sql] def tryUncacheQuery( - query: SchemaRDD, + query: DataFrame, blocking: Boolean = true): Boolean = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) @@ -122,8 +124,8 @@ private[sql] trait CacheManager { found } - /** 
Optionally returns cached data for the given SchemaRDD */ - private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock { + /** Optionally returns cached data for the given [[DataFrame]] */ + private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock { lookupCachedData(query.queryExecution.analyzed) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala new file mode 100644 index 0000000000000..174c403059510 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -0,0 +1,584 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.language.implicitConversions + +import org.apache.spark.sql.Dsl.lit +import org.apache.spark.sql.catalyst.analysis.{UnresolvedStar, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} +import org.apache.spark.sql.types._ + + +object Column { + /** + * Creates a [[Column]] based on the given column name. Same as [[Dsl.col]]. 
+ */ + def apply(colName: String): Column = new Column(colName) + + /** For internal pattern matching. */ + private[sql] def unapply(col: Column): Option[Expression] = Some(col.expr) +} + + +/** + * A column in a [[DataFrame]]. + * + * `Column` instances can be created by: + * {{{ + * // 1. Select a column out of a DataFrame + * df("colName") + * + * // 2. Create a literal expression + * Literal(1) + * + * // 3. Create new columns from + * }}} + * + */ +// TODO: Improve documentation. +class Column( + sqlContext: Option[SQLContext], + plan: Option[LogicalPlan], + protected[sql] val expr: Expression) + extends DataFrame(sqlContext, plan) with ExpressionApi { + + /** Turns a Catalyst expression into a `Column`. */ + protected[sql] def this(expr: Expression) = this(None, None, expr) + + /** + * Creates a new `Column` expression based on a column or attribute name. + * The resolution of this is the same as SQL. For example: + * + * - "colName" becomes an expression selecting the column named "colName". + * - "*" becomes an expression selecting all columns. + * - "df.*" becomes an expression selecting all columns in data frame "df". + */ + def this(name: String) = this(name match { + case "*" => UnresolvedStar(None) + case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2))) + case _ => UnresolvedAttribute(name) + }) + + override def isComputable: Boolean = sqlContext.isDefined && plan.isDefined + + /** + * An implicit conversion function internal to this class. This function creates a new Column + * based on an expression. If the expression itself is not named, it aliases the expression + * by calling it "col". + */ + private[this] implicit def toColumn(expr: Expression): Column = { + val projectedPlan = plan.map { p => + Project(Seq(expr match { + case named: NamedExpression => named + case unnamed: Expression => Alias(unnamed, "col")() + }), p) + } + new Column(sqlContext, projectedPlan, expr) + } + + /** + * Unary minus, i.e. 
negate the expression. + * {{{ + * // Select the amount column and negate all values. + * df.select( -df("amount") ) + * }}} + */ + override def unary_- : Column = UnaryMinus(expr) + + /** + * Bitwise NOT. + * {{{ + * // Select the flags column and negate every bit. + * df.select( ~df("flags") ) + * }}} + */ + override def unary_~ : Column = BitwiseNot(expr) + + /** + * Inversion of boolean expression, i.e. NOT. + * {{{ + * // Select rows that are not active (isActive === false) + * df.select( !df("isActive") ) + * }}} + */ + override def unary_! : Column = Not(expr) + + + /** + * Equality test with an expression. + * {{{ + * // The following two both select rows in which colA equals colB. + * df.select( df("colA") === df("colB") ) + * df.select( df("colA").equalTo(df("colB")) ) + * }}} + */ + override def === (other: Column): Column = EqualTo(expr, other.expr) + + /** + * Equality test with a literal value. + * {{{ + * // The following two both select rows in which colA is "Zaharia". + * df.select( df("colA") === "Zaharia") + * df.select( df("colA").equalTo("Zaharia") ) + * }}} + */ + override def === (literal: Any): Column = this === lit(literal) + + /** + * Equality test with an expression. + * {{{ + * // The following two both select rows in which colA equals colB. + * df.select( df("colA") === df("colB") ) + * df.select( df("colA").equalTo(df("colB")) ) + * }}} + */ + override def equalTo(other: Column): Column = this === other + + /** + * Equality test with a literal value. + * {{{ + * // The following two both select rows in which colA is "Zaharia". + * df.select( df("colA") === "Zaharia") + * df.select( df("colA").equalTo("Zaharia") ) + * }}} + */ + override def equalTo(literal: Any): Column = this === literal + + /** + * Inequality test with an expression. + * {{{ + * // The following two both select rows in which colA does not equal colB.
+ * df.select( df("colA") !== df("colB") ) + * df.select( !(df("colA") === df("colB")) ) + * }}} + */ + override def !== (other: Column): Column = Not(EqualTo(expr, other.expr)) + + /** + * Inequality test with a literal value. + * {{{ + * // The following two both select rows in which colA does not equal 15. + * df.select( df("colA") !== 15 ) + * df.select( !(df("colA") === 15) ) + * }}} + */ + override def !== (literal: Any): Column = this !== lit(literal) + + /** + * Greater than an expression. + * {{{ + * // The following selects people older than 21. + * people.select( people("age") > Literal(21) ) + * }}} + */ + override def > (other: Column): Column = GreaterThan(expr, other.expr) + + /** + * Greater than a literal value. + * {{{ + * // The following selects people older than 21. + * people.select( people("age") > 21 ) + * }}} + */ + override def > (literal: Any): Column = this > lit(literal) + + /** + * Less than an expression. + * {{{ + * // The following selects people younger than 21. + * people.select( people("age") < Literal(21) ) + * }}} + */ + override def < (other: Column): Column = LessThan(expr, other.expr) + + /** + * Less than a literal value. + * {{{ + * // The following selects people younger than 21. + * people.select( people("age") < 21 ) + * }}} + */ + override def < (literal: Any): Column = this < lit(literal) + + /** + * Less than or equal to an expression. + * {{{ + * // The following selects people aged 21 or younger. + * people.select( people("age") <= Literal(21) ) + * }}} + */ + override def <= (other: Column): Column = LessThanOrEqual(expr, other.expr) + + /** + * Less than or equal to a literal value. + * {{{ + * // The following selects people aged 21 or younger. + * people.select( people("age") <= 21 ) + * }}} + */ + override def <= (literal: Any): Column = this <= lit(literal) + + /** + * Greater than or equal to an expression. + * {{{ + * // The following selects people age 21 or older than 21.
+ * people.select( people("age") >= Literal(21) ) + * }}} + */ + override def >= (other: Column): Column = GreaterThanOrEqual(expr, other.expr) + + /** + * Greater than or equal to a literal value. + * {{{ + * // The following selects people age 21 or older than 21. + * people.select( people("age") >= 21 ) + * }}} + */ + override def >= (literal: Any): Column = this >= lit(literal) + + /** + * Equality test with an expression that is safe for null values. + */ + override def <=> (other: Column): Column = other match { + case null => EqualNullSafe(expr, lit(null).expr) + case _ => EqualNullSafe(expr, other.expr) + } + + /** + * Equality test with a literal value that is safe for null values. + */ + override def <=> (literal: Any): Column = this <=> lit(literal) + + /** + * True if the current expression is null. + */ + override def isNull: Column = IsNull(expr) + + /** + * True if the current expression is NOT null. + */ + override def isNotNull: Column = IsNotNull(expr) + + /** + * Boolean OR with an expression. + * {{{ + * // The following selects people that are in school or employed. + * people.select( people("inSchool") || people("isEmployed") ) + * }}} + */ + override def || (other: Column): Column = Or(expr, other.expr) + + /** + * Boolean OR with a literal value. + * {{{ + * // The following selects everything. + * people.select( people("inSchool") || true ) + * }}} + */ + override def || (literal: Boolean): Column = this || lit(literal) + + /** + * Boolean AND with an expression. + * {{{ + * // The following selects people that are in school and employed at the same time. + * people.select( people("inSchool") && people("isEmployed") ) + * }}} + */ + override def && (other: Column): Column = And(expr, other.expr) + + /** + * Boolean AND with a literal value. + * {{{ + * // The following selects people that are in school. 
+ * people.select( people("inSchool") && true ) + * }}} + */ + override def && (literal: Boolean): Column = this && lit(literal) + + /** + * Bitwise AND with an expression. + */ + override def & (other: Column): Column = BitwiseAnd(expr, other.expr) + + /** + * Bitwise AND with a literal value. + */ + override def & (literal: Any): Column = this & lit(literal) + + /** + * Bitwise OR with an expression. + */ + override def | (other: Column): Column = BitwiseOr(expr, other.expr) + + /** + * Bitwise OR with a literal value. + */ + override def | (literal: Any): Column = this | lit(literal) + + /** + * Bitwise XOR with an expression. + */ + override def ^ (other: Column): Column = BitwiseXor(expr, other.expr) + + /** + * Bitwise XOR with a literal value. + */ + override def ^ (literal: Any): Column = this ^ lit(literal) + + /** + * Sum of this expression and another expression. + * {{{ + * // The following selects the sum of a person's height and weight. + * people.select( people("height") + people("weight") ) + * }}} + */ + override def + (other: Column): Column = Add(expr, other.expr) + + /** + * Sum of this expression and another expression. + * {{{ + * // The following selects the sum of a person's height and 10. + * people.select( people("height") + 10 ) + * }}} + */ + override def + (literal: Any): Column = this + lit(literal) + + /** + * Subtraction. Subtract the other expression from this expression. + * {{{ + * // The following selects the difference between people's height and their weight. + * people.select( people("height") - people("weight") ) + * }}} + */ + override def - (other: Column): Column = Subtract(expr, other.expr) + + /** + * Subtraction. Subtract a literal value from this expression. + * {{{ + * // The following selects a person's height and subtract it by 10. + * people.select( people("height") - 10 ) + * }}} + */ + override def - (literal: Any): Column = this - lit(literal) + + /** + * Multiplication of this expression and another expression. 
+ * {{{ + * // The following multiplies a person's height by their weight. + * people.select( people("height") * people("weight") ) + * }}} + */ + override def * (other: Column): Column = Multiply(expr, other.expr) + + /** + * Multiplication of this expression and a literal value. + * {{{ + * // The following multiplies a person's height by 10. + * people.select( people("height") * 10 ) + * }}} + */ + override def * (literal: Any): Column = this * lit(literal) + + /** + * Division of this expression by another expression. + * {{{ + * // The following divides a person's height by their weight. + * people.select( people("height") / people("weight") ) + * }}} + */ + override def / (other: Column): Column = Divide(expr, other.expr) + + /** + * Division of this expression by a literal value. + * {{{ + * // The following divides a person's height by 10. + * people.select( people("height") / 10 ) + * }}} + */ + override def / (literal: Any): Column = this / lit(literal) + + /** + * Modulo (a.k.a. remainder) expression. + */ + override def % (other: Column): Column = Remainder(expr, other.expr) + + /** + * Modulo (a.k.a. remainder) expression. + */ + override def % (literal: Any): Column = this % lit(literal) + + + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the evaluated values of the arguments. + */ + @scala.annotation.varargs + override def in(list: Column*): Column = In(expr, list.map(_.expr)) + + override def like(literal: String): Column = Like(expr, lit(literal).expr) + + override def rlike(literal: String): Column = RLike(expr, lit(literal).expr) + + /** + * An expression that gets an item at position `ordinal` out of an array. + */ + override def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal)) + + /** + * An expression that gets a field by name in a [[StructType]].
+ */ + override def getField(fieldName: String): Column = GetField(expr, fieldName) + + /** + * An expression that returns a substring. + * @param startPos expression for the starting position. + * @param len expression for the length of the substring. + */ + override def substr(startPos: Column, len: Column): Column = + Substring(expr, startPos.expr, len.expr) + + /** + * An expression that returns a substring. + * @param startPos starting position. + * @param len length of the substring. + */ + override def substr(startPos: Int, len: Int): Column = this.substr(lit(startPos), lit(len)) + + override def contains(other: Column): Column = Contains(expr, other.expr) + + override def contains(literal: Any): Column = this.contains(lit(literal)) + + + override def startsWith(other: Column): Column = StartsWith(expr, other.expr) + + override def startsWith(literal: String): Column = this.startsWith(lit(literal)) + + override def endsWith(other: Column): Column = EndsWith(expr, other.expr) + + override def endsWith(literal: String): Column = this.endsWith(lit(literal)) + + /** + * Gives the column an alias. + * {{{ + * // Renames colA to colB in select output. + * df.select($"colA".as("colB")) + * }}} + */ + override def as(alias: String): Column = Alias(expr, alias)() + + /** + * Casts the column to a different data type. + * {{{ + * // Casts colA to IntegerType. + * import org.apache.spark.sql.types.IntegerType + * df.select(df("colA").cast(IntegerType)) + * + * // equivalent to + * df.select(df("colA").cast("int")) + * }}} + */ + override def cast(to: DataType): Column = Cast(expr, to) + + /** + * Casts the column to a different data type, using the canonical string representation + * of the type. The supported types are: `string`, `boolean`, `byte`, `short`, `int`, `long`, + * `float`, `double`, `decimal`, `date`, `timestamp`. + * {{{ + * // Casts colA to integer. 
+ * df.select(df("colA").cast("int")) + * }}} + */ + override def cast(to: String): Column = Cast(expr, to.toLowerCase match { + case "string" => StringType + case "boolean" => BooleanType + case "byte" => ByteType + case "short" => ShortType + case "int" => IntegerType + case "long" => LongType + case "float" => FloatType + case "double" => DoubleType + case "decimal" => DecimalType.Unlimited + case "date" => DateType + case "timestamp" => TimestampType + case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""") + }) + + override def desc: Column = SortOrder(expr, Descending) + + override def asc: Column = SortOrder(expr, Ascending) +} + + +class ColumnName(name: String) extends Column(name) { + + /** Creates a new [[StructField]] of type boolean, named after this column */ + def boolean: StructField = StructField(name, BooleanType) + + /** Creates a new [[StructField]] of type byte, named after this column */ + def byte: StructField = StructField(name, ByteType) + + /** Creates a new [[StructField]] of type short, named after this column */ + def short: StructField = StructField(name, ShortType) + + /** Creates a new [[StructField]] of type int, named after this column */ + def int: StructField = StructField(name, IntegerType) + + /** Creates a new [[StructField]] of type long, named after this column */ + def long: StructField = StructField(name, LongType) + + /** Creates a new [[StructField]] of type float, named after this column */ + def float: StructField = StructField(name, FloatType) + + /** Creates a new [[StructField]] of type double, named after this column */ + def double: StructField = StructField(name, DoubleType) + + /** Creates a new [[StructField]] of type string, named after this column */ + def string: StructField = StructField(name, StringType) + + /** Creates a new [[StructField]] of type date, named after this column */ + def date: StructField = StructField(name, DateType) + + /** Creates a new [[StructField]] of type decimal, using `DecimalType.Unlimited` */ + def decimal: StructField = StructField(name, DecimalType.Unlimited) + + /** Creates a new [[StructField]] of type decimal with the given precision and scale */ + def decimal(precision: Int, scale: Int): StructField = +
StructField(name, DecimalType(precision, scale)) + + /** Creates a new [[StructField]] of type timestamp, named after this column */ + def timestamp: StructField = StructField(name, TimestampType) + + /** Creates a new [[StructField]] of type binary, named after this column */ + def binary: StructField = StructField(name, BinaryType) + + /** Creates a new [[StructField]] of type array with the given element type */ + def array(dataType: DataType): StructField = StructField(name, ArrayType(dataType)) + + /** Creates a new [[StructField]] of type map with the given key and value types */ + def map(keyType: DataType, valueType: DataType): StructField = + map(MapType(keyType, valueType)) + + def map(mapType: MapType): StructField = StructField(name, mapType) + + /** Creates a new [[StructField]] of type struct made of the given fields */ + def struct(fields: StructField*): StructField = struct(StructType(fields)) + + def struct(structType: StructType): StructField = StructField(name, structType) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala new file mode 100644 index 0000000000000..1096e396591df --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -0,0 +1,664 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License.
+*/ + +package org.apache.spark.sql + +import java.util.{List => JList} + +import scala.language.implicitConversions +import scala.reflect.ClassTag +import scala.collection.JavaConversions._ + +import com.fasterxml.jackson.core.JsonFactory + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python.SerDeUtil +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} +import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.types.{NumericType, StructType} +import org.apache.spark.util.Utils + + +/** + * A collection of rows that have the same columns. + * + * A [[DataFrame]] is equivalent to a relational table in Spark SQL, and can be created using + * various functions in [[SQLContext]]. + * {{{ + * val people = sqlContext.parquetFile("...") + * }}} + * + * Once created, it can be manipulated using the various domain-specific-language (DSL) functions + * defined in: [[DataFrame]] (this class), [[Column]], [[Dsl]] for the DSL. + * + * To select a column from the data frame, use the apply method: + * {{{ + * val ageCol = people("age") // in Scala + * Column ageCol = people.apply("age") // in Java + * }}} + * + * Note that the [[Column]] type can also be manipulated through its various functions. + * {{{ + * // The following creates a new column that increases everybody's age by 10.
+ * people("age") + 10 // in Scala + * }}} + * + * A more concrete example: + * {{{ + * // To create DataFrame using SQLContext + * val people = sqlContext.parquetFile("...") + * val department = sqlContext.parquetFile("...") + * + * people.filter(people("age") > 30) + * .join(department, people("deptId") === department("id")) + * .groupBy(department("name"), people("gender")) + * .agg(avg(people("salary")), max(people("age"))) + * }}} + */ +// TODO: Improve documentation. +class DataFrame protected[sql]( + val sqlContext: SQLContext, + private val baseLogicalPlan: LogicalPlan, + operatorsEnabled: Boolean) + extends DataFrameSpecificApi with RDDApi[Row] { + + protected[sql] def this(sqlContext: Option[SQLContext], plan: Option[LogicalPlan]) = + this(sqlContext.orNull, plan.orNull, sqlContext.isDefined && plan.isDefined) + + protected[sql] def this(sqlContext: SQLContext, plan: LogicalPlan) = this(sqlContext, plan, true) + + @transient protected[sql] lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan) + + @transient protected[sql] val logicalPlan: LogicalPlan = baseLogicalPlan match { + // For various commands (like DDL) and queries with side effects, we force query optimization to + // happen right away to let these side effects take place eagerly. + case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile => + LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) + case _ => + baseLogicalPlan + } + + /** + * An implicit conversion function internal to this class for us to avoid doing + * "new DataFrame(...)" everywhere. + */ + private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = { + new DataFrame(sqlContext, logicalPlan, true) + } + + /** Returns the list of numeric columns, useful for doing aggregation.
*/ + protected[sql] def numericColumns: Seq[Expression] = { + schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => + queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get + } + } + + /** Resolves a column name into a Catalyst [[NamedExpression]]. */ + protected[sql] def resolve(colName: String): NamedExpression = { + queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse { + throw new RuntimeException( + s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") + } + } + + /** Left here for compatibility reasons. Returns this object itself; use `toDataFrame` instead. */ + @deprecated("use toDataFrame", "1.3.0") + def toSchemaRDD: DataFrame = this + + /** + * Returns the object itself. Used to force an implicit conversion from RDD to DataFrame in Scala. + */ + def toDataFrame: DataFrame = this + + /** + * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion + * from a RDD of tuples into a [[DataFrame]] with meaningful names. For example: + * {{{ + * val rdd: RDD[(Int, String)] = ... + * rdd.toDataFrame // this implicit conversion creates a DataFrame with column name _1 and _2 + * rdd.toDataFrame("id", "name") // this creates a DataFrame with column name "id" and "name" + * }}} + */ + @scala.annotation.varargs + def toDataFrame(colName: String, colNames: String*): DataFrame = { + val newNames = colName +: colNames + require(schema.size == newNames.size, + "The number of columns doesn't match.\n" + + "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" + + "New column names: " + newNames.mkString(", ")) + + val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) => + apply(oldName).as(newName) + } + select(newCols :_*) + } + + /** Returns the schema of this [[DataFrame]]. */ + override def schema: StructType = queryExecution.analyzed.schema + + /** Returns all column names and their data types as an array.
*/ + override def dtypes: Array[(String, String)] = schema.fields.map { field => + (field.name, field.dataType.toString) + } + + /** Returns all column names as an array. */ + override def columns: Array[String] = schema.fields.map(_.name) + + /** Prints the schema to the console in a nice tree format. */ + override def printSchema(): Unit = println(schema.treeString) + + /** + * Cartesian join with another [[DataFrame]]. + * + * Note that cartesian joins are very expensive without an extra filter that can be pushed down. + * + * @param right Right side of the join operation. + */ + override def join(right: DataFrame): DataFrame = { + Join(logicalPlan, right.logicalPlan, joinType = Inner, None) + } + + /** + * Inner join with another [[DataFrame]], using the given join expression. + * + * {{{ + * // The following two are equivalent: + * df1.join(df2, $"df1Key" === $"df2Key") + * df1.join(df2).where($"df1Key" === $"df2Key") + * }}} + */ + override def join(right: DataFrame, joinExprs: Column): DataFrame = { + Join(logicalPlan, right.logicalPlan, Inner, Some(joinExprs.expr)) + } + + /** + * Join with another [[DataFrame]], using the given join expression. The following performs + * a full outer join between `df1` and `df2`. + * + * {{{ + * df1.join(df2, $"df1Key" === $"df2Key", "outer") + * }}} + * + * @param right Right side of the join. + * @param joinExprs Join expression. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + */ + override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = { + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)) + } + + /** + * Returns a new [[DataFrame]] sorted by the specified column, all in ascending order.
+ * {{{ + * // The following 3 are equivalent + * df.sort("sortcol") + * df.sort($"sortcol") + * df.sort($"sortcol".asc) + * }}} + */ + @scala.annotation.varargs + override def sort(sortCol: String, sortCols: String*): DataFrame = { + orderBy(apply(sortCol), sortCols.map(apply) :_*) + } + + /** + * Returns a new [[DataFrame]] sorted by the given expressions. For example: + * {{{ + * df.sort($"col1", $"col2".desc) + * }}} + */ + @scala.annotation.varargs + override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = { + val sortOrder: Seq[SortOrder] = (sortExpr +: sortExprs).map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + Sort(sortOrder, global = true, logicalPlan) + } + + /** + * Returns a new [[DataFrame]] sorted by the given expressions. + * This is an alias of the `sort` function. + */ + @scala.annotation.varargs + override def orderBy(sortCol: String, sortCols: String*): DataFrame = { + sort(sortCol, sortCols :_*) + } + + /** + * Returns a new [[DataFrame]] sorted by the given expressions. + * This is an alias of the `sort` function. + */ + @scala.annotation.varargs + override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = { + sort(sortExpr, sortExprs :_*) + } + + /** + * Selects column based on the column name and return it as a [[Column]]. + */ + override def apply(colName: String): Column = colName match { + case "*" => + new Column(ResolvedStar(schema.fieldNames.map(resolve))) + case _ => + val expr = resolve(colName) + new Column(Some(sqlContext), Some(Project(Seq(expr), logicalPlan)), expr) + } + + /** + * Selects a set of expressions, wrapped in a Product. 
+ * {{{ + * // The following two are equivalent: + * df.apply(($"colA", $"colB" + 1)) + * df.select($"colA", $"colB" + 1) + * }}} + */ + override def apply(projection: Product): DataFrame = { + require(projection.productArity >= 1) + select(projection.productIterator.map { + case c: Column => c + case o: Any => new Column(Some(sqlContext), None, Literal(o)) + }.toSeq :_*) + } + + /** + * Returns a new [[DataFrame]] with an alias set. + */ + override def as(name: String): DataFrame = Subquery(name, logicalPlan) + + /** + * Selects a set of expressions. Expressions that are not already named are aliased + * using their string representation. + * {{{ + * df.select($"colA", $"colB" + 1) + * }}} + */ + @scala.annotation.varargs + override def select(cols: Column*): DataFrame = { + val exprs = cols.map { + case Column(expr: NamedExpression) => + expr + case Column(expr: Expression) => + Alias(expr, expr.toString)() + } + Project(exprs.toSeq, logicalPlan) + } + + /** + * Selects a set of columns. This is a variant of `select` that can only select + * existing columns using column names (i.e. cannot construct expressions). + * + * {{{ + * // The following two are equivalent: + * df.select("colA", "colB") + * df.select($"colA", $"colB") + * }}} + */ + @scala.annotation.varargs + override def select(col: String, cols: String*): DataFrame = { + select((col +: cols).map(new Column(_)) :_*) + } + + /** + * Filters rows using the given condition. + * {{{ + * // The following are equivalent: + * peopleDf.filter($"age" > 15) + * peopleDf.where($"age" > 15) + * peopleDf($"age" > 15) + * }}} + */ + override def filter(condition: Column): DataFrame = { + Filter(condition.expr, logicalPlan) + } + + /** + * Filters rows using the given condition. This is an alias for `filter`. + * {{{ + * // The following are equivalent: + * peopleDf.filter($"age" > 15) + * peopleDf.where($"age" > 15) + * peopleDf($"age" > 15) + * }}} + */ + override def where(condition: Column): DataFrame = filter(condition) + + /** + * Filters rows using the given condition.
This is a shorthand meant for Scala. + * {{{ + * // The following are equivalent: + * peopleDf.filter($"age" > 15) + * peopleDf.where($"age" > 15) + * peopleDf($"age" > 15) + * }}} + */ + override def apply(condition: Column): DataFrame = filter(condition) + + /** + * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * See [[GroupedDataFrame]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns grouped by department. + * df.groupBy($"department").avg() + * + * // Compute the max age and average salary, grouped by department and gender. + * df.groupBy($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + */ + @scala.annotation.varargs + override def groupBy(cols: Column*): GroupedDataFrame = { + new GroupedDataFrame(this, cols.map(_.expr)) + } + + /** + * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * See [[GroupedDataFrame]] for all the available aggregate functions. + * + * This is a variant of groupBy that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns grouped by department. + * df.groupBy("department").avg() + * + * // Compute the max age and average salary, grouped by department and gender. + * df.groupBy("department", "gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + */ + @scala.annotation.varargs + override def groupBy(col1: String, cols: String*): GroupedDataFrame = { + val colNames: Seq[String] = col1 +: cols + new GroupedDataFrame(this, colNames.map(colName => resolve(colName))) + } + + /** + * Aggregates on the entire [[DataFrame]] without groups. + * {{{ + * // df.agg(...) is a shorthand for df.groupBy().agg(...)
+ * df.agg(Map("age" -> "max", "salary" -> "avg")) + * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) + * }}} + */ + override def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) + + /** + * Aggregates on the entire [[DataFrame]] without groups. + * {{{ + * // df.agg(...) is a shorthand for df.groupBy().agg(...) + * df.agg(Map("age" -> "max", "salary" -> "avg")) + * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) + * }}} + */ + override def agg(exprs: java.util.Map[String, String]): DataFrame = agg(exprs.toMap) + + /** + * Aggregates on the entire [[DataFrame]] without groups. + * {{{ + * // df.agg(...) is a shorthand for df.groupBy().agg(...) + * df.agg(max($"age"), avg($"salary")) + * df.groupBy().agg(max($"age"), avg($"salary")) + * }}} + */ + @scala.annotation.varargs + override def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*) + + /** + * Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function + * and `head` is that `head` returns an array while `limit` returns a new [[DataFrame]]. + */ + override def limit(n: Int): DataFrame = Limit(Literal(n), logicalPlan) + + /** + * Returns a new [[DataFrame]] containing union of rows in this frame and another frame. + * This is equivalent to `UNION ALL` in SQL. + */ + override def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan) + + /** + * Returns a new [[DataFrame]] containing rows only in both this frame and another frame. + * This is equivalent to `INTERSECT` in SQL. + */ + override def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan) + + /** + * Returns a new [[DataFrame]] containing rows in this frame but not in another frame. + * This is equivalent to `EXCEPT` in SQL. + */ + override def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan) + + /** + * Returns a new [[DataFrame]] by sampling a fraction of rows.
+ * + * @param withReplacement Sample with replacement or not. + * @param fraction Fraction of rows to generate. + * @param seed Seed for sampling. + */ + override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = { + Sample(fraction, withReplacement, seed, logicalPlan) + } + + /** + * Returns a new [[DataFrame]] by sampling a fraction of rows, using a random seed. + * + * @param withReplacement Sample with replacement or not. + * @param fraction Fraction of rows to generate. + */ + override def sample(withReplacement: Boolean, fraction: Double): DataFrame = { + sample(withReplacement, fraction, Utils.random.nextLong) + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Returns a new [[DataFrame]] by adding a column. + */ + override def addColumn(colName: String, col: Column): DataFrame = { + select(Column("*"), col.as(colName)) + } + + /** + * Returns the first `n` rows. + */ + override def head(n: Int): Array[Row] = limit(n).collect() + + /** + * Returns the first row. + */ + override def head(): Row = head(1).head + + /** + * Returns the first row. Alias for head(). + */ + override def first(): Row = head() + + /** + * Returns a new RDD by applying a function to all rows of this DataFrame. + */ + override def map[R: ClassTag](f: Row => R): RDD[R] = { + rdd.map(f) + } + + /** + * Returns a new RDD by first applying a function to all rows of this [[DataFrame]], + * and then flattening the results. + */ + override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f) + + /** + * Returns a new RDD by applying a function to each partition of this DataFrame. + */ + override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { + rdd.mapPartitions(f) + } + + /** + * Applies a function `f` to all rows. + */ + override def foreach(f: Row => Unit): Unit = rdd.foreach(f) + + /** + * Applies a function f to each partition of this [[DataFrame]]. 
+   */
+  override def foreachPartition(f: Iterator[Row] => Unit): Unit = {
+    rdd.foreachPartition(f)
+  }
+
+  /**
+   * Returns the first `n` rows of the [[DataFrame]]; delegates to `head(n)`.
+   */
+  override def take(n: Int): Array[Row] = {
+    head(n)
+  }
+
+  /**
+   * Collects every [[Row]] of this [[DataFrame]] into an array.
+   */
+  override def collect(): Array[Row] = {
+    rdd.collect()
+  }
+
+  /**
+   * Collects every [[Row]] of this [[DataFrame]] and exposes the result as a Java list.
+   */
+  override def collectAsList(): java.util.List[Row] = {
+    val rows = rdd.collect()
+    java.util.Arrays.asList(rows: _*)
+  }
+
+  /**
+   * Counts the rows of the [[DataFrame]] by running an ungrouped `count()` aggregation and
+   * reading the long in column 0 of the first collected row.
+   */
+  override def count(): Long = {
+    val counted = groupBy().count()
+    val collected = counted.rdd.collect()
+    collected.head.getLong(0)
+  }
+
+  /**
+   * Builds a new [[DataFrame]] from this frame's `rdd` repartitioned into `numPartitions`
+   * partitions, re-applying the current schema.
+   */
+  override def repartition(numPartitions: Int): DataFrame = {
+    val reshuffled = rdd.repartition(numPartitions)
+    sqlContext.applySchema(reshuffled, schema)
+  }
+
+  /** Registers this query with `sqlContext.cacheManager.cacheQuery` and returns this instance. */
+  override def persist(): this.type = {
+    val manager = sqlContext.cacheManager
+    manager.cacheQuery(this)
+    this
+  }
+
+  /**
+   * Registers this query with `sqlContext.cacheManager.cacheQuery`, passing `newLevel` as the
+   * storage level, and returns this instance.
+   */
+  override def persist(newLevel: StorageLevel): this.type = {
+    val manager = sqlContext.cacheManager
+    manager.cacheQuery(this, None, newLevel)
+    this
+  }
+
+  /**
+   * Asks `sqlContext.cacheManager.tryUncacheQuery` to drop this query, forwarding `blocking`,
+   * and returns this instance.
+   */
+  override def unpersist(blocking: Boolean): this.type = {
+    val manager = sqlContext.cacheManager
+    manager.tryUncacheQuery(this, blocking)
+    this
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+  // I/O
+  /////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * Executes the physical plan and exposes the content of the [[DataFrame]] as an [[RDD]] of
+   * [[Row]]s, passing each row through `ScalaReflection.convertRowToScala`.
+   */
+  override def rdd: RDD[Row] = {
+    // The schema is read into a local before the closure is built.
+    // NOTE(review): presumably so the closure does not capture `this` -- confirm.
+    val rowSchema = this.schema
+    val executed = queryExecution.executedPlan.execute()
+    executed.map(row => ScalaReflection.convertRowToScala(row, rowSchema))
+  }
+
+  /**
+   * Registers this RDD as a temporary table using the given name. The lifetime of this temporary
+   * table is tied to the [[SQLContext]] that was used to create this DataFrame.
+   *
+   * @group schema
+   */
+  override def registerTempTable(tableName: String): Unit =
+    sqlContext.registerRDDAsTable(this, tableName)
+
+  /**
+   * Writes the contents of this [[DataFrame]] to `path` as a parquet file, keeping the schema.
+   * Files written this way can be read back in as a [[DataFrame]] using the `parquetFile`
+   * function in [[SQLContext]].
+   */
+  override def saveAsParquetFile(path: String): Unit = {
+    val writePlan = WriteToFile(path, logicalPlan)
+    sqlContext.executePlan(writePlan).toRdd
+  }
+
+  /**
+   * :: Experimental ::
+   * Creates a table from the contents of this DataFrame. This will fail if the table already
+   * exists.
+   *
+   * Note that this currently only works with DataFrames that are created from a HiveContext as
+   * there is no notion of a persisted catalog in a standard SQL context. Instead you can write
+   * an RDD out to a parquet file, and then register that file as a table. This "table" can then
+   * be the target of an `insertInto`.
+   */
+  @Experimental
+  override def saveAsTable(tableName: String): Unit = {
+    val createPlan = CreateTableAsSelect(None, tableName, logicalPlan, allowExisting = false)
+    sqlContext.executePlan(createPlan).toRdd
+  }
+
+  /**
+   * :: Experimental ::
+   * Inserts the rows of this RDD into the table named `tableName`; `overwrite` is passed
+   * through to the `InsertIntoTable` plan to control whether existing data is replaced.
+   */
+  @Experimental
+  override def insertInto(tableName: String, overwrite: Boolean): Unit = {
+    val target = UnresolvedRelation(Seq(tableName))
+    sqlContext.executePlan(InsertIntoTable(target, Map.empty, logicalPlan, overwrite)).toRdd
+  }
+
+  /**
+   * Returns the content of the [[DataFrame]] as a RDD of JSON strings.
+   */
+  override def toJSON: RDD[String] = {
+    // Read the schema into a local before building the per-partition closure.
+    val currentSchema = this.schema
+    this.mapPartitions { rows =>
+      // One JsonFactory is created per partition and shared by all rows in it.
+      val factory = new JsonFactory()
+      rows.map(JsonRDD.rowToJSON(currentSchema, factory))
+    }
+  }
+
+  ////////////////////////////////////////////////////////////////////////////
+  // for Python API
+  ////////////////////////////////////////////////////////////////////////////
+  /**
+   * Py4j helper: turns a Java list of [[Column]]s into an array.
+   */
+  protected[sql] def toColumnArray(cols: JList[Column]): Array[Column] = {
+    val asScalaList = cols.toList
+    asScalaList.toArray
+  }
+
+  /**
+   * Converts every row to an array with `EvaluatePython.rowToArray` (using the schema's field
+   * data types) and hands the resulting JavaRDD to `SerDeUtil.javaToPython`.
+   */
+  protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+    val columnTypes = schema.fields.map(_.dataType)
+    val asArrays = rdd.map(row => EvaluatePython.rowToArray(row, columnTypes)).toJavaRDD()
+    SerDeUtil.javaToPython(asArrays)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
new file mode 100644
index 0000000000000..3499956023d11
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
@@ -0,0 +1,529 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.{TypeTag, typeTag}
+
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+
+/**
+ * Domain specific functions available for [[DataFrame]].
+ */
+object Dsl {
+
+  /** Implicitly turns a Scala `Symbol` into a [[ColumnName]] carrying the symbol's name. */
+  implicit def symbolToColumn(s: Symbol): ColumnName = {
+    val name = s.name
+    new ColumnName(name)
+  }
+
+  // /**
+  //  * An implicit conversion that turns a RDD of product into a [[DataFrame]].
+  //  *
+  //  * This method requires an implicit SQLContext in scope. For example:
+  //  * {{{
+  //  *   implicit val sqlContext: SQLContext = ...
+  //  *   val rdd: RDD[(Int, String)] = ...
+  //  *   rdd.toDataFrame // triggers the implicit here
+  //  * }}}
+  //  */
+  // implicit def rddToDataFrame[A <: Product: TypeTag](rdd: RDD[A])(implicit context: SQLContext)
+  //   : DataFrame = {
+  //   context.createDataFrame(rdd)
+  // }
+
+  /** Converts $"col name" into a [[Column]]. */
+  implicit class StringToColumn(val sc: StringContext) extends AnyVal {
+    def $(args: Any*): ColumnName = {
+      // Interpolate with the standard `s` interpolator, then use the result as the column name.
+      val name = sc.s(args: _*)
+      new ColumnName(name)
+    }
+  }
+
+  private[this] implicit def toColumn(expr: Expression): Column = {
+    new Column(expr)
+  }
+
+  /**
+   * Builds a [[Column]] referring to the column named `colName`.
+   */
+  def col(colName: String): Column = {
+    new Column(colName)
+  }
+
+  /**
+   * Builds a [[Column]] referring to the column named `colName`. Alias of [[col]].
+   */
+  def column(colName: String): Column = col(colName)
+
+  /**
+   * Creates a [[Column]] of literal value.
+   */
+  def lit(literal: Any): Column = literal match {
+    // A Symbol is turned into a named column instead of a literal expression.
+    case sym: Symbol =>
+      new ColumnName(sym.name)
+
+    case value =>
+      // Pick the Catalyst data type from the runtime class of the value; anything that is
+      // not listed here (and is not null) is rejected.
+      val literalExpr = value match {
+        case v: Boolean => Literal(v, BooleanType)
+        case v: Byte => Literal(v, ByteType)
+        case v: Short => Literal(v, ShortType)
+        case v: Int => Literal(v, IntegerType)
+        case v: Long => Literal(v, LongType)
+        case v: Float => Literal(v, FloatType)
+        case v: Double => Literal(v, DoubleType)
+        case v: String => Literal(v, StringType)
+        case v: BigDecimal => Literal(Decimal(v), DecimalType.Unlimited)
+        case v: java.math.BigDecimal => Literal(Decimal(v), DecimalType.Unlimited)
+        case v: Decimal => Literal(v, DecimalType.Unlimited)
+        case v: java.sql.Timestamp => Literal(v, TimestampType)
+        case v: java.sql.Date => Literal(v, DateType)
+        case v: Array[Byte] => Literal(v, BinaryType)
+        case null => Literal(null, NullType)
+        case _ =>
+          throw new RuntimeException("Unsupported literal type " + literal.getClass + " " + literal)
+      }
+      new Column(literalExpr)
+  }
+
+  /** Builds a `Sum` expression over `e`. */
+  def sum(e: Column): Column = new Column(Sum(e.expr))
+
+  /** Builds a `SumDistinct` expression over `e`. */
+  def sumDistinct(e: Column): Column = new Column(SumDistinct(e.expr))
+
+  /** Builds a `Count` expression over `e`. */
+  def count(e: Column): Column = new Column(Count(e.expr))
+
+  /** Builds a `CountDistinct` expression over `expr` followed by `exprs`. */
+  @scala.annotation.varargs
+  def countDistinct(expr: Column, exprs: Column*): Column = {
+    val children = (expr +: exprs).map(_.expr)
+    new Column(CountDistinct(children))
+  }
+
+  /** Builds an `ApproxCountDistinct` expression over `e`. */
+  def approxCountDistinct(e: Column): Column = {
+    new Column(ApproxCountDistinct(e.expr))
+  }
+
+  /** Builds an `ApproxCountDistinct` expression over `e`, passing `rsd` through to it. */
+  def approxCountDistinct(e: Column, rsd: Double): Column = {
+    new Column(ApproxCountDistinct(e.expr, rsd))
+  }
+
+  /** Builds an `Average` expression over `e`. */
+  def avg(e: Column): Column = new Column(Average(e.expr))
+
+  /** Builds a `First` expression over `e`. */
+  def first(e: Column): Column = new Column(First(e.expr))
+
+  /** Builds a `Last` expression over `e`. */
+  def last(e: Column): Column = new Column(Last(e.expr))
+
+  /** Builds a `Min` expression over `e`. */
+  def min(e: Column): Column = new Column(Min(e.expr))
+
+  /** Builds a `Max` expression over `e`. */
+  def max(e: Column): Column = new Column(Max(e.expr))
+
+  /** Builds an `Upper` expression over `e`. */
+  def upper(e: Column): Column = new Column(Upper(e.expr))
+
+  /** Builds a `Lower` expression over `e`. */
+  def lower(e: Column): Column = new Column(Lower(e.expr))
+
+  /** Builds a `Sqrt` expression over `e`. */
+  def sqrt(e: Column): Column = new Column(Sqrt(e.expr))
+
+  /** Builds an `Abs` expression over `e`. */
+  def abs(e: Column): Column = new Column(Abs(e.expr))
+
+
+  // scalastyle:off
+
+  /* Use the following code to generate:
+    (0 to 22).map {
x =>
+      val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"})
+      val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _)
+      val args = (1 to x).map(i => s"arg$i: Column").mkString(", ")
+      val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ")
+      println(s"""
+        /**
+         * Call a Scala function of ${x} arguments as user-defined function (UDF), and automatically
+         * infer the data types based on the function's signature.
+         */
+        def callUDF[$typeTags](f: Function$x[$types]${if (args.length > 0) ", " + args else ""}): Column = {
+          ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq($argsInUdf))
+        }""")
+    }
+
+    (0 to 22).map { x =>
+      val args = (1 to x).map(i => s"arg$i: Column").mkString(", ")
+      val fTypes = Seq.fill(x + 1)("_").mkString(", ")
+      val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ")
+      println(s"""
+        /**
+         * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires
+         * you to specify the return data type.
+         */
+        def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = {
+          ScalaUdf(f, returnType, Seq($argsInUdf))
+        }""")
+    }
+  }
+  */
+  /**
+   * Wraps the zero-argument Scala function `f` in a `ScalaUdf` with no child expressions.
+   * The return data type is looked up from `RT` through `ScalaReflection.schemaFor`.
+   */
+  def callUDF[RT: TypeTag](
+      f: Function0[RT]): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq.empty)
+  }
+
+  /**
+   * Wraps the 1-argument Scala function `f` in a `ScalaUdf` applied to the given column.
+   * The return data type is looked up from `RT` through `ScalaReflection.schemaFor`.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag](
+      f: Function1[A1, RT],
+      arg1: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1).map(_.expr))
+  }
+
+  /**
+   * Call a Scala function of 2 arguments as user-defined function (UDF), and automatically
+   * infer the data types based on the function's signature.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag](
+      f: Function2[A1, A2, RT],
+      arg1: Column, arg2: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2).map(_.expr))
+  }
+
+  /**
+   * Wraps the 3-argument Scala function `f` in a `ScalaUdf` applied to the given columns.
+   * The return data type is looked up from `RT` through `ScalaReflection.schemaFor`.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](
+      f: Function3[A1, A2, A3, RT],
+      arg1: Column, arg2: Column, arg3: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3).map(_.expr))
+  }
+
+  /**
+   * Wraps the 4-argument Scala function `f` in a `ScalaUdf` applied to the given columns.
+   * The return data type is looked up from `RT` through `ScalaReflection.schemaFor`.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](
+      f: Function4[A1, A2, A3, A4, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4).map(_.expr))
+  }
+
+  /**
+   * Wraps the 5-argument Scala function `f` in a `ScalaUdf` applied to the given columns.
+   * The return data type is looked up from `RT` through `ScalaReflection.schemaFor`.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](
+      f: Function5[A1, A2, A3, A4, A5, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5).map(_.expr))
+  }
+
+  /**
+   * Call a Scala function of 6 arguments as user-defined function (UDF), and automatically
+   * infer the data types based on the function's signature.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](
+      f: Function6[A1, A2, A3, A4, A5, A6, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6).map(_.expr))
+  }
+
+  /**
+   * Wraps the 7-argument Scala function `f` in a `ScalaUdf` applied to the given columns.
+   * The return data type is looked up from `RT` through `ScalaReflection.schemaFor`.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](
+      f: Function7[A1, A2, A3, A4, A5, A6, A7, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7).map(_.expr))
+  }
+
+  /**
+   * Wraps the 8-argument Scala function `f` in a `ScalaUdf` applied to the given columns.
+   * The return data type is looked up from `RT` through `ScalaReflection.schemaFor`.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](
+      f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8).map(_.expr))
+  }
+
+  /**
+   * Call a Scala function of 9 arguments as user-defined function (UDF), and automatically
+   * infer the data types based on the function's signature.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](
+      f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9).map(_.expr))
+  }
+
+  /**
+   * Wraps the 10-argument Scala function `f` in a `ScalaUdf` applied to the given columns.
+   * The return data type is looked up from `RT` through `ScalaReflection.schemaFor`.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](
+      f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10).map(_.expr))
+  }
+
+  /**
+   * Call a Scala function of 11 arguments as user-defined function (UDF), and automatically
+   * infer the data types based on the function's signature.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](
+      f: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11).map(_.expr))
+  }
+
+  /**
+   * Wraps the 12-argument Scala function `f` in a `ScalaUdf` applied to the given columns.
+   * The return data type is looked up from `RT` through `ScalaReflection.schemaFor`.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](
+      f: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12).map(_.expr))
+  }
+
+  /**
+   * Call a Scala function of 13 arguments as user-defined function (UDF), and automatically
+   * infer the data types based on the function's signature.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](
+      f: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13).map(_.expr))
+  }
+
+  /**
+   * Wraps the 14-argument Scala function `f` in a `ScalaUdf` applied to the given columns.
+   * The return data type is looked up from `RT` through `ScalaReflection.schemaFor`.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](
+      f: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14).map(_.expr))
+  }
+
+  /**
+   * Call a Scala function of 15 arguments as user-defined function (UDF), and automatically
+   * infer the data types based on the function's signature.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](
+      f: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15).map(_.expr))
+  }
+
+  /**
+   * Wraps the 16-argument Scala function `f` in a `ScalaUdf` applied to the given columns.
+   * The return data type is looked up from `RT` through `ScalaReflection.schemaFor`.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](
+      f: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16).map(_.expr))
+  }
+
+  /**
+   * Call a Scala function of 17 arguments as user-defined function (UDF), and automatically
+   * infer the data types based on the function's signature.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](
+      f: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17).map(_.expr))
+  }
+
+  /**
+   * Call a Scala function of 18 arguments as user-defined function (UDF), and automatically
+   * infer the data types based on the function's signature.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](
+      f: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18).map(_.expr))
+  }
+
+  /**
+   * Call a Scala function of 19 arguments as user-defined function (UDF), and automatically
+   * infer the data types based on the function's signature.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](
+      f: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19).map(_.expr))
+  }
+
+  /**
+   * Call a Scala function of 20 arguments as user-defined function (UDF), and automatically
+   * infer the data types based on the function's signature.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](
+      f: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20).map(_.expr))
+  }
+
+  /**
+   * Call a Scala function of 21 arguments as user-defined function (UDF), and automatically
+   * infer the data types based on the function's signature.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](
+      f: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20, arg21).map(_.expr))
+  }
+
+  /**
+   * Call a Scala function of 22 arguments as user-defined function (UDF), and automatically
+   * infer the data types based on the function's signature.
+   */
+  def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](
+      f: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT],
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = {
+    val returnType = ScalaReflection.schemaFor(typeTag[RT]).dataType
+    ScalaUdf(f, returnType, Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20, arg21, arg22).map(_.expr))
+  }
+
+  //////////////////////////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * Wraps the zero-argument Scala function `f` in a `ScalaUdf` with no child expressions,
+   * using the caller-supplied `returnType` as the result data type.
+   */
+  def callUDF(
+      f: Function0[_],
+      returnType: DataType): Column = {
+    ScalaUdf(f, returnType, Seq.empty)
+  }
+
+  /**
+   * Wraps the 1-argument Scala function `f` in a `ScalaUdf` applied to the given column,
+   * using the caller-supplied `returnType` as the result data type.
+   */
+  def callUDF(
+      f: Function1[_, _],
+      returnType: DataType,
+      arg1: Column): Column = {
+    val children = Seq(arg1).map(_.expr)
+    ScalaUdf(f, returnType, children)
+  }
+
+  /**
+   * Call a Scala function of 2 arguments as user-defined function (UDF). This requires
+   * you to specify the return data type.
+   */
+  def callUDF(
+      f: Function2[_, _, _],
+      returnType: DataType,
+      arg1: Column, arg2: Column): Column = {
+    val children = Seq(arg1, arg2).map(_.expr)
+    ScalaUdf(f, returnType, children)
+  }
+
+  /**
+   * Wraps the 3-argument Scala function `f` in a `ScalaUdf` applied to the given columns,
+   * using the caller-supplied `returnType` as the result data type.
+   */
+  def callUDF(
+      f: Function3[_, _, _, _],
+      returnType: DataType,
+      arg1: Column, arg2: Column, arg3: Column): Column = {
+    val children = Seq(arg1, arg2, arg3).map(_.expr)
+    ScalaUdf(f, returnType, children)
+  }
+
+  /**
+   * Wraps the 4-argument Scala function `f` in a `ScalaUdf` applied to the given columns,
+   * using the caller-supplied `returnType` as the result data type.
+   */
+  def callUDF(
+      f: Function4[_, _, _, _, _],
+      returnType: DataType,
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = {
+    val children = Seq(arg1, arg2, arg3, arg4).map(_.expr)
+    ScalaUdf(f, returnType, children)
+  }
+
+  /**
+   * Wraps the 5-argument Scala function `f` in a `ScalaUdf` applied to the given columns,
+   * using the caller-supplied `returnType` as the result data type.
+   */
+  def callUDF(
+      f: Function5[_, _, _, _, _, _],
+      returnType: DataType,
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = {
+    val children = Seq(arg1, arg2, arg3, arg4, arg5).map(_.expr)
+    ScalaUdf(f, returnType, children)
+  }
+
+  /**
+   * Wraps the 6-argument Scala function `f` in a `ScalaUdf` applied to the given columns,
+   * using the caller-supplied `returnType` as the result data type.
+   */
+  def callUDF(
+      f: Function6[_, _, _, _, _, _, _],
+      returnType: DataType,
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = {
+    val children = Seq(arg1, arg2, arg3, arg4, arg5, arg6).map(_.expr)
+    ScalaUdf(f, returnType, children)
+  }
+
+  /**
+   * Call a Scala function of 7 arguments as user-defined function (UDF). This requires
+   * you to specify the return data type.
+   */
+  def callUDF(
+      f: Function7[_, _, _, _, _, _, _, _],
+      returnType: DataType,
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = {
+    val children = Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7).map(_.expr)
+    ScalaUdf(f, returnType, children)
+  }
+
+  /**
+   * Wraps the 8-argument Scala function `f` in a `ScalaUdf` applied to the given columns,
+   * using the caller-supplied `returnType` as the result data type.
+   */
+  def callUDF(
+      f: Function8[_, _, _, _, _, _, _, _, _],
+      returnType: DataType,
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = {
+    val children = Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8).map(_.expr)
+    ScalaUdf(f, returnType, children)
+  }
+
+  /**
+   * Wraps the 9-argument Scala function `f` in a `ScalaUdf` applied to the given columns,
+   * using the caller-supplied `returnType` as the result data type.
+   */
+  def callUDF(
+      f: Function9[_, _, _, _, _, _, _, _, _, _],
+      returnType: DataType,
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = {
+    val children = Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9).map(_.expr)
+    ScalaUdf(f, returnType, children)
+  }
+
+  /**
+   * Wraps the 10-argument Scala function `f` in a `ScalaUdf` applied to the given columns,
+   * using the caller-supplied `returnType` as the result data type.
+   */
+  def callUDF(
+      f: Function10[_, _, _, _, _, _, _, _, _, _, _],
+      returnType: DataType,
+      arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = {
+    val children = Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10).map(_.expr)
+    ScalaUdf(f, returnType, children)
+  }
+
+  /**
+   * Call a Scala function of 11 arguments as user-defined function (UDF). This requires
+   * you to specify the return data type.
+ */ + def callUDF(f: Function11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr)) + } + + /** + * Call a Scala function of 12 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr)) + } + + /** + * Call a Scala function of 13 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr)) + } + + /** + * Call a Scala function of 14 arguments as user-defined function (UDF). This requires + * you to specify the return data type. 
+ */ + def callUDF(f: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr)) + } + + /** + * Call a Scala function of 15 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr)) + } + + /** + * Call a Scala function of 16 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr)) + } + + /** + * Call a Scala function of 17 arguments as user-defined function (UDF). 
This requires + * you to specify the return data type. + */ + def callUDF(f: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr)) + } + + /** + * Call a Scala function of 18 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr)) + } + + /** + * Call a Scala function of 19 arguments as user-defined function (UDF). This requires + * you to specify the return data type. 
+ */ + def callUDF(f: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr)) + } + + /** + * Call a Scala function of 20 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr)) + } + + /** + * Call a Scala function of 21 arguments as user-defined function (UDF). This requires + * you to specify the return data type. 
+ */ + def callUDF(f: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr)) + } + + /** + * Call a Scala function of 22 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr)) + } + + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala new file mode 100644 index 0000000000000..1c948cbbfe58f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.collection.JavaConversions._ + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr} +import org.apache.spark.sql.catalyst.plans.logical.Aggregate + + +/** + * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]]. 
+ */ +class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) + extends GroupedDataFrameApi { + + /** Implicitly builds the aggregate [[DataFrame]], prepending the grouping columns. */ + private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = { + val named = groupingExprs.map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.toString)() + } + new DataFrame(df.sqlContext, Aggregate(groupingExprs, named ++ aggExprs, df.logicalPlan)) + } + + /** Applies `f` to each of `df.numericColumns`, aliasing every result by its string form. */ + private[this] def aggregateNumericColumns(f: Expression => Expression): Seq[NamedExpression] = + df.numericColumns.map { c => + val a = f(c) + Alias(a, a.toString)() + } + + /** Maps an aggregate method name (case-insensitive) to its expression constructor. */ + private[this] def strToExpr(expr: String): (Expression => Expression) = expr.toLowerCase match { + case "avg" | "average" | "mean" => Average + case "max" => Max + case "min" => Min + case "sum" => Sum + case "count" | "size" => Count + case _ => throw new IllegalArgumentException(s"Unsupported aggregate method: $expr") + } + + /** + * Compute aggregates by specifying a map from column name to aggregate methods. The resulting + * [[DataFrame]] will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg(Map( + * "age" -> "max", + * "expense" -> "sum" + * )) + * }}} + */ + override def agg(exprs: Map[String, String]): DataFrame = { + exprs.map { case (colName, expr) => + val a = strToExpr(expr)(df(colName).expr) + Alias(a, a.toString)() + }.toSeq + } + + /** + * Compute aggregates by specifying a map from column name to aggregate methods. The resulting + * [[DataFrame]] will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
+ * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg(Map( + * "age" -> "max", + * "expense" -> "sum" + * )) + * }}} + */ + def agg(exprs: java.util.Map[String, String]): DataFrame = { + agg(exprs.toMap) + } + + /** + * Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this + * class, the resulting [[DataFrame]] won't automatically include the grouping columns. + * + * The available aggregate methods are defined in [[org.apache.spark.sql.Dsl]]. + * + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * import org.apache.spark.sql.Dsl._ + * df.groupBy("department").agg($"department", max($"age"), sum($"expense")) + * }}} + */ + @scala.annotation.varargs + override def agg(expr: Column, exprs: Column*): DataFrame = { + val aggExprs = (expr +: exprs).map(_.expr).map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.toString)() + } + + new DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan)) + } + + /** + * Count the number of rows for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + override def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")()) + + /** + * Compute the average value for each numeric column for each group. This is an alias for `avg`. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + override def mean(): DataFrame = aggregateNumericColumns(Average) + + /** + * Compute the max value for each numeric column for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + override def max(): DataFrame = aggregateNumericColumns(Max) + + /** + * Compute the mean value for each numeric column for each group. + * The resulting [[DataFrame]] will also contain the grouping columns.
+ */ + override def avg(): DataFrame = aggregateNumericColumns(Average) + + /** + * Compute the min value for each numeric column for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + override def min(): DataFrame = aggregateNumericColumns(Min) + + /** + * Compute the sum for each numeric column for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + override def sum(): DataFrame = aggregateNumericColumns(Sum) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 0a22968cc7807..f87fde4ed8165 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -30,7 +30,6 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.dsl.ExpressionConversions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -43,7 +42,7 @@ import org.apache.spark.util.Utils /** * :: AlphaComponent :: - * The entry point for running relational queries using Spark. Allows the creation of [[SchemaRDD]] + * The entry point for running relational queries using Spark. Allows the creation of [[DataFrame]] * objects and the execution of SQL queries.
* * @groupname userf Spark SQL Functions @@ -52,8 +51,6 @@ import org.apache.spark.util.Utils @AlphaComponent class SQLContext(@transient val sparkContext: SparkContext) extends org.apache.spark.Logging - with CacheManager - with ExpressionConversions with Serializable { self => @@ -111,37 +108,82 @@ class SQLContext(@transient val sparkContext: SparkContext) } protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) - protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution { val logical = plan } + + protected[sql] def executePlan(plan: LogicalPlan) = new this.QueryExecution(plan) sparkContext.getConf.getAll.foreach { case (key, value) if key.startsWith("spark.sql") => setConf(key, value) case _ => } + protected[sql] val cacheManager = new CacheManager(this) + + /** + * A collection of methods that are considered experimental, but can be used to hook into + * the query planner for advanced functionalities. + */ + val experimental: ExperimentalMethods = new ExperimentalMethods(this) + /** - * Creates a SchemaRDD from an RDD of case classes. + * A collection of methods for registering user-defined functions (UDF). + * + * The following example registers a Scala closure as UDF: + * {{{ + * sqlContext.udf.register("myUdf", (arg1: Int, arg2: String) => arg2 + arg1) + * }}} + * + * The following example registers a UDF in Java: + * {{{ + * sqlContext.udf().register("myUDF", + * new UDF2<Integer, String, String>() { + * @Override + * public String call(Integer arg1, String arg2) { + * return arg2 + arg1; + * } + * }, DataTypes.StringType); + * }}} + * + * Or, to use Java 8 lambda syntax: + * {{{ + * sqlContext.udf().register("myUDF", + * (Integer arg1, String arg2) -> arg2 + arg1, + * DataTypes.StringType); + * }}} + */ + val udf: UDFRegistration = new UDFRegistration(this) + + /** Returns true if the table is currently cached in-memory.
*/ + def isCached(tableName: String): Boolean = cacheManager.isCached(tableName) + + /** Caches the specified table in-memory. */ + def cacheTable(tableName: String): Unit = cacheManager.cacheTable(tableName) + + /** Removes the specified table from the in-memory cache. */ + def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName) + + /** + * Creates a DataFrame from an RDD of case classes. * * @group userf */ - implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]): SchemaRDD = { + implicit def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = { SparkPlan.currentContext.set(self) val attributeSeq = ScalaReflection.attributesFor[A] val schema = StructType.fromAttributes(attributeSeq) val rowRDD = RDDConversions.productToRowRdd(rdd, schema) - new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self)) + new DataFrame(this, LogicalRDD(attributeSeq, rowRDD)(self)) } /** - * Convert a [[BaseRelation]] created for external data sources into a [[SchemaRDD]]. + * Convert a [[BaseRelation]] created for external data sources into a [[DataFrame]]. */ - def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = { - new SchemaRDD(this, LogicalRelation(baseRelation)) + def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = { + new DataFrame(this, LogicalRelation(baseRelation)) } /** * :: DeveloperApi :: - * Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. + * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. * It is important to make sure that the structure of every [[Row]] of the provided RDD matches * the provided schema. Otherwise, there will be runtime exception. 
* Example: @@ -157,24 +199,24 @@ class SQLContext(@transient val sparkContext: SparkContext) * val people = * sc.textFile("examples/src/main/resources/people.txt").map( * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) - * val peopleSchemaRDD = sqlContext. applySchema(people, schema) - * peopleSchemaRDD.printSchema + * val dataFrame = sqlContext.applySchema(people, schema) + * dataFrame.printSchema * // root * // |-- name: string (nullable = false) * // |-- age: integer (nullable = true) * - * peopleSchemaRDD.registerTempTable("people") + * dataFrame.registerTempTable("people") * sqlContext.sql("select name from people").collect.foreach(println) * }}} * * @group userf */ @DeveloperApi - def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = { - // TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied + def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { + // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self) - new SchemaRDD(this, logicalPlan) + new DataFrame(this, logicalPlan) } /** @@ -183,7 +225,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order.
*/ - def applySchema(rdd: RDD[_], beanClass: Class[_]): SchemaRDD = { + def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { val attributeSeq = getSchema(beanClass) val className = beanClass.getName val rowRdd = rdd.mapPartitions { iter => @@ -201,7 +243,7 @@ class SQLContext(@transient val sparkContext: SparkContext) ) : Row } } - new SchemaRDD(this, LogicalRDD(attributeSeq, rowRdd)(this)) + new DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this)) } /** @@ -210,35 +252,35 @@ class SQLContext(@transient val sparkContext: SparkContext) * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. */ - def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): SchemaRDD = { + def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { applySchema(rdd.rdd, beanClass) } /** - * Loads a Parquet file, returning the result as a [[SchemaRDD]]. + * Loads a Parquet file, returning the result as a [[DataFrame]]. * * @group userf */ - def parquetFile(path: String): SchemaRDD = - new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) + def parquetFile(path: String): DataFrame = + new DataFrame(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) /** - * Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]]. + * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. * It goes through the entire dataset once to determine the schema. * * @group userf */ - def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0) + def jsonFile(path: String): DataFrame = jsonFile(path, 1.0) /** * :: Experimental :: * Loads a JSON file (one object per line) and applies the given schema, - * returning the result as a [[SchemaRDD]]. + * returning the result as a [[DataFrame]]. 
* * @group userf */ @Experimental - def jsonFile(path: String, schema: StructType): SchemaRDD = { + def jsonFile(path: String, schema: StructType): DataFrame = { val json = sparkContext.textFile(path) jsonRDD(json, schema) } @@ -247,29 +289,29 @@ class SQLContext(@transient val sparkContext: SparkContext) * :: Experimental :: */ @Experimental - def jsonFile(path: String, samplingRatio: Double): SchemaRDD = { + def jsonFile(path: String, samplingRatio: Double): DataFrame = { val json = sparkContext.textFile(path) jsonRDD(json, samplingRatio) } /** * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[SchemaRDD]]. + * [[DataFrame]]. * It goes through the entire dataset once to determine the schema. * * @group userf */ - def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0) + def jsonRDD(json: RDD[String]): DataFrame = jsonRDD(json, 1.0) /** * :: Experimental :: * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, - * returning the result as a [[SchemaRDD]]. + * returning the result as a [[DataFrame]]. 
* * @group userf */ @Experimental - def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = { + def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord val appliedSchema = Option(schema).getOrElse( @@ -283,7 +325,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * :: Experimental :: */ @Experimental - def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { + def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord val appliedSchema = JsonRDD.nullTypeToStringType( @@ -298,8 +340,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = { - catalog.registerTable(Seq(tableName), rdd.queryExecution.logical) + def registerRDDAsTable(rdd: DataFrame, tableName: String): Unit = { + catalog.registerTable(Seq(tableName), rdd.logicalPlan) } /** @@ -311,61 +353,27 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def dropTempTable(tableName: String): Unit = { - tryUncacheQuery(table(tableName)) + cacheManager.tryUncacheQuery(table(tableName)) catalog.unregisterTable(Seq(tableName)) } /** - * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is + * Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is * used for SQL parsing can be configured with 'spark.sql.dialect'. 
* * @group userf */ - def sql(sqlText: String): SchemaRDD = { + def sql(sqlText: String): DataFrame = { if (conf.dialect == "sql") { - new SchemaRDD(this, parseSql(sqlText)) + new DataFrame(this, parseSql(sqlText)) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}") } } - /** Returns the specified table as a SchemaRDD */ - def table(tableName: String): SchemaRDD = - new SchemaRDD(this, catalog.lookupRelation(Seq(tableName))) - - /** - * A collection of methods that are considered experimental, but can be used to hook into - * the query planner for advanced functionalities. - */ - val experimental: ExperimentalMethods = new ExperimentalMethods(this) - - /** - * A collection of methods for registering user-defined functions (UDF). - * - * The following example registers a Scala closure as UDF: - * {{{ - * sqlContext.udf.register("myUdf", (arg1: Int, arg2: String) => arg2 + arg1) - * }}} - * - * The following example registers a UDF in Java: - * {{{ - * sqlContext.udf().register("myUDF", - * new UDF2() { - * @Override - * public String call(Integer arg1, String arg2) { - * return arg2 + arg1; - * } - * }, DataTypes.StringType); - * }}} - * - * Or, to use Java 8 lambda syntax: - * {{{ - * sqlContext.udf().register("myUDF", - * (Integer arg1, String arg2) -> arg2 + arg1), - * DataTypes.StringType); - * }}} - */ - val udf: UDFRegistration = new UDFRegistration(this) + /** Returns the specified table as a [[DataFrame]]. */ + def table(tableName: String): DataFrame = + new DataFrame(this, catalog.lookupRelation(Seq(tableName))) protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext @@ -454,15 +462,14 @@ class SQLContext(@transient val sparkContext: SparkContext) * access to the intermediate phases of query execution for developers. 
*/ @DeveloperApi - protected abstract class QueryExecution { - def logical: LogicalPlan + protected class QueryExecution(val logical: LogicalPlan) { - lazy val analyzed = ExtractPythonUdfs(analyzer(logical)) - lazy val withCachedData = useCachedData(analyzed) - lazy val optimizedPlan = optimizer(withCachedData) + lazy val analyzed: LogicalPlan = ExtractPythonUdfs(analyzer(logical)) + lazy val withCachedData: LogicalPlan = cacheManager.useCachedData(analyzed) + lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData) // TODO: Don't just pick the first one... - lazy val sparkPlan = { + lazy val sparkPlan: SparkPlan = { SparkPlan.currentContext.set(self) planner(optimizedPlan).next() } @@ -512,7 +519,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ protected[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], - schemaString: String): SchemaRDD = { + schemaString: String): DataFrame = { val schema = parseDataType(schemaString).asInstanceOf[StructType] applySchemaToPythonRDD(rdd, schema) } @@ -522,7 +529,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ protected[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], - schema: StructType): SchemaRDD = { + schema: StructType): DataFrame = { def needsConversion(dataType: DataType): Boolean = dataType match { case ByteType => true @@ -549,7 +556,7 @@ class SQLContext(@transient val sparkContext: SparkContext) iter.map { m => new GenericRow(m): Row} } - new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) + new DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala deleted file mode 100644 index d1e21dffeb8c5..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ /dev/null @@ -1,511 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license 
agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import java.util.{List => JList} - -import scala.collection.JavaConversions._ - -import com.fasterxml.jackson.core.JsonFactory - -import net.razorvine.pickle.Pickler - -import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext} -import org.apache.spark.annotation.{AlphaComponent, Experimental} -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.python.SerDeUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} -import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.types.{BooleanType, StructType} -import org.apache.spark.storage.StorageLevel - -/** - * :: AlphaComponent :: - * An RDD of [[Row]] objects that has an associated schema. In addition to standard RDD functions, - * SchemaRDDs can be used in relational queries, as shown in the examples below. 
- * - * Importing a SQLContext brings an implicit into scope that automatically converts a standard RDD - * whose elements are scala case classes into a SchemaRDD. This conversion can also be done - * explicitly using the `createSchemaRDD` function on a [[SQLContext]]. - * - * A `SchemaRDD` can also be created by loading data in from external sources. - * Examples are loading data from Parquet files by using the `parquetFile` method on [[SQLContext]] - * and loading JSON datasets by using `jsonFile` and `jsonRDD` methods on [[SQLContext]]. - * - * == SQL Queries == - * A SchemaRDD can be registered as a table in the [[SQLContext]] that was used to create it. Once - * an RDD has been registered as a table, it can be used in the FROM clause of SQL statements. - * - * {{{ - * // One method for defining the schema of an RDD is to make a case class with the desired column - * // names and types. - * case class Record(key: Int, value: String) - * - * val sc: SparkContext // An existing spark context. - * val sqlContext = new SQLContext(sc) - * - * // Importing the SQL context gives access to all the SQL functions and implicit conversions. - * import sqlContext._ - * - * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) - * // Any RDD containing case classes can be registered as a table. The schema of the table is - * // automatically inferred using scala reflection. - * rdd.registerTempTable("records") - * - * val results: SchemaRDD = sql("SELECT * FROM records") - * }}} - * - * == Language Integrated Queries == - * - * {{{ - * - * case class Record(key: Int, value: String) - * - * val sc: SparkContext // An existing spark context. - * val sqlContext = new SQLContext(sc) - * - * // Importing the SQL context gives access to all the SQL functions and implicit conversions. - * import sqlContext._ - * - * val rdd = sc.parallelize((1 to 100).map(i => Record(i, "val_" + i))) - * - * // Example of language integrated queries. 
- * rdd.where('key === 1).orderBy('value.asc).select('key).collect() - * }}} - * - * @groupname Query Language Integrated Queries - * @groupdesc Query Functions that create new queries from SchemaRDDs. The - * result of all query functions is also a SchemaRDD, allowing multiple operations to be - * chained using a builder pattern. - * @groupprio Query -2 - * @groupname schema SchemaRDD Functions - * @groupprio schema -1 - * @groupname Ungrouped Base RDD Functions - */ -@AlphaComponent -class SchemaRDD( - @transient val sqlContext: SQLContext, - @transient val baseLogicalPlan: LogicalPlan) - extends RDD[Row](sqlContext.sparkContext, Nil) with SchemaRDDLike { - - def baseSchemaRDD = this - - // ========================================================================================= - // RDD functions: Copy the internal row representation so we present immutable data to users. - // ========================================================================================= - - override def compute(split: Partition, context: TaskContext): Iterator[Row] = - firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema)) - - override def getPartitions: Array[Partition] = firstParent[Row].partitions - - override protected def getDependencies: Seq[Dependency[_]] = { - schema // Force reification of the schema so it is available on executors. - - List(new OneToOneDependency(queryExecution.toRdd)) - } - - /** - * Returns the schema of this SchemaRDD (represented by a [[StructType]]). - * - * @group schema - */ - lazy val schema: StructType = queryExecution.analyzed.schema - - /** - * Returns a new RDD with each row transformed to a JSON string. 
- * - * @group schema - */ - def toJSON: RDD[String] = { - val rowSchema = this.schema - this.mapPartitions { iter => - val jsonFactory = new JsonFactory() - iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory)) - } - } - - - // ======================================================================= - // Query DSL - // ======================================================================= - - /** - * Changes the output of this relation to the given expressions, similar to the `SELECT` clause - * in SQL. - * - * {{{ - * schemaRDD.select('a, 'b + 'c, 'd as 'aliasedName) - * }}} - * - * @param exprs a set of logical expression that will be evaluated for each input row. - * - * @group Query - */ - def select(exprs: Expression*): SchemaRDD = { - val aliases = exprs.zipWithIndex.map { - case (ne: NamedExpression, _) => ne - case (e, i) => Alias(e, s"c$i")() - } - new SchemaRDD(sqlContext, Project(aliases, logicalPlan)) - } - - /** - * Filters the output, only returning those rows where `condition` evaluates to true. - * - * {{{ - * schemaRDD.where('a === 'b) - * schemaRDD.where('a === 1) - * schemaRDD.where('a + 'b > 10) - * }}} - * - * @group Query - */ - def where(condition: Expression): SchemaRDD = - new SchemaRDD(sqlContext, Filter(condition, logicalPlan)) - - /** - * Performs a relational join on two SchemaRDDs - * - * @param otherPlan the [[SchemaRDD]] that should be joined with this one. - * @param joinType One of `Inner`, `LeftOuter`, `RightOuter`, or `FullOuter`. Defaults to `Inner.` - * @param on An optional condition for the join operation. This is equivalent to the `ON` - * clause in standard SQL. In the case of `Inner` joins, specifying a - * `condition` is equivalent to adding `where` clauses after the `join`. 
- * - * @group Query - */ - def join( - otherPlan: SchemaRDD, - joinType: JoinType = Inner, - on: Option[Expression] = None): SchemaRDD = - new SchemaRDD(sqlContext, Join(logicalPlan, otherPlan.logicalPlan, joinType, on)) - - /** - * Sorts the results by the given expressions. - * {{{ - * schemaRDD.orderBy('a) - * schemaRDD.orderBy('a, 'b) - * schemaRDD.orderBy('a.asc, 'b.desc) - * }}} - * - * @group Query - */ - def orderBy(sortExprs: SortOrder*): SchemaRDD = - new SchemaRDD(sqlContext, Sort(sortExprs, true, logicalPlan)) - - /** - * Sorts the results by the given expressions within partition. - * {{{ - * schemaRDD.sortBy('a) - * schemaRDD.sortBy('a, 'b) - * schemaRDD.sortBy('a.asc, 'b.desc) - * }}} - * - * @group Query - */ - def sortBy(sortExprs: SortOrder*): SchemaRDD = - new SchemaRDD(sqlContext, Sort(sortExprs, false, logicalPlan)) - - @deprecated("use limit with integer argument", "1.1.0") - def limit(limitExpr: Expression): SchemaRDD = - new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan)) - - /** - * Limits the results by the given integer. - * {{{ - * schemaRDD.limit(10) - * }}} - * @group Query - */ - def limit(limitNum: Int): SchemaRDD = - new SchemaRDD(sqlContext, Limit(Literal(limitNum), logicalPlan)) - - /** - * Performs a grouping followed by an aggregation. - * - * {{{ - * schemaRDD.groupBy('year)(Sum('sales) as 'totalSales) - * }}} - * - * @group Query - */ - def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): SchemaRDD = { - val aliasedExprs = aggregateExprs.map { - case ne: NamedExpression => ne - case e => Alias(e, e.toString)() - } - new SchemaRDD(sqlContext, Aggregate(groupingExprs, aliasedExprs, logicalPlan)) - } - - /** - * Performs an aggregation over all Rows in this RDD. - * This is equivalent to a groupBy with no grouping expressions. 
- * - * {{{ - * schemaRDD.aggregate(Sum('sales) as 'totalSales) - * }}} - * - * @group Query - */ - def aggregate(aggregateExprs: Expression*): SchemaRDD = { - groupBy()(aggregateExprs: _*) - } - - /** - * Applies a qualifier to the attributes of this relation. Can be used to disambiguate attributes - * with the same name, for example, when performing self-joins. - * - * {{{ - * val x = schemaRDD.where('a === 1).as('x) - * val y = schemaRDD.where('a === 2).as('y) - * x.join(y).where("x.a".attr === "y.a".attr), - * }}} - * - * @group Query - */ - def as(alias: Symbol) = - new SchemaRDD(sqlContext, Subquery(alias.name, logicalPlan)) - - /** - * Combines the tuples of two RDDs with the same schema, keeping duplicates. - * - * @group Query - */ - def unionAll(otherPlan: SchemaRDD) = - new SchemaRDD(sqlContext, Union(logicalPlan, otherPlan.logicalPlan)) - - /** - * Performs a relational except on two SchemaRDDs - * - * @param otherPlan the [[SchemaRDD]] that should be excepted from this one. - * - * @group Query - */ - def except(otherPlan: SchemaRDD): SchemaRDD = - new SchemaRDD(sqlContext, Except(logicalPlan, otherPlan.logicalPlan)) - - /** - * Performs a relational intersect on two SchemaRDDs - * - * @param otherPlan the [[SchemaRDD]] that should be intersected with this one. - * - * @group Query - */ - def intersect(otherPlan: SchemaRDD): SchemaRDD = - new SchemaRDD(sqlContext, Intersect(logicalPlan, otherPlan.logicalPlan)) - - /** - * Filters tuples using a function over the value of the specified column. - * - * {{{ - * schemaRDD.where('a)((a: Int) => ...) - * }}} - * - * @group Query - */ - def where[T1](arg1: Symbol)(udf: (T1) => Boolean) = - new SchemaRDD( - sqlContext, - Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)) - - /** - * :: Experimental :: - * Returns a sampled version of the underlying dataset. 
- * - * @group Query - */ - @Experimental - override - def sample( - withReplacement: Boolean = true, - fraction: Double, - seed: Long) = - new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan)) - - /** - * :: Experimental :: - * Return the number of elements in the RDD. Unlike the base RDD implementation of count, this - * implementation leverages the query optimizer to compute the count on the SchemaRDD, which - * supports features such as filter pushdown. - * - * @group Query - */ - @Experimental - override def count(): Long = aggregate(Count(Literal(1))).collect().head.getLong(0) - - /** - * :: Experimental :: - * Applies the given Generator, or table generating function, to this relation. - * - * @param generator A table generating function. The API for such functions is likely to change - * in future releases - * @param join when set to true, each output row of the generator is joined with the input row - * that produced it. - * @param outer when set to true, at least one row will be produced for each input row, similar to - * an `OUTER JOIN` in SQL. When no output rows are produced by the generator for a - * given row, a single row will be output, with `NULL` values for each of the - * generated columns. - * @param alias an optional alias that can be used as qualifier for the attributes that are - * produced by this generate operation. - * - * @group Query - */ - @Experimental - def generate( - generator: Generator, - join: Boolean = false, - outer: Boolean = false, - alias: Option[String] = None) = - new SchemaRDD(sqlContext, Generate(generator, join, outer, alias, logicalPlan)) - - /** - * Returns this RDD as a SchemaRDD. Intended primarily to force the invocation of the implicit - * conversion from a standard RDD to a SchemaRDD. - * - * @group schema - */ - def toSchemaRDD = this - - /** - * Converts a JavaRDD to a PythonRDD. It is used by pyspark. 
- */ - private[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val jrdd = this.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() - SerDeUtil.javaToPython(jrdd) - } - - /** - * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same - * format as javaToPython. It is used by pyspark. - */ - private[sql] def collectToPython: JList[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val pickle = new Pickler - new java.util.ArrayList(collect().map { row => - EvaluatePython.rowToArray(row, fieldTypes) - }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) - } - - /** - * Serializes the Array[Row] returned by SchemaRDD's takeSample(), using the same - * format as javaToPython and collectToPython. It is used by pyspark. - */ - private[sql] def takeSampleToPython( - withReplacement: Boolean, - num: Int, - seed: Long): JList[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val pickle = new Pickler - new java.util.ArrayList(this.takeSample(withReplacement, num, seed).map { row => - EvaluatePython.rowToArray(row, fieldTypes) - }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) - } - - /** - * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value - * of base RDD functions that do not change schema. 
- * - * @param rdd RDD derived from this one and has same schema - * - * @group schema - */ - private def applySchema(rdd: RDD[Row]): SchemaRDD = { - new SchemaRDD(sqlContext, - LogicalRDD(queryExecution.analyzed.output.map(_.newInstance()), rdd)(sqlContext)) - } - - // ======================================================================= - // Overridden RDD actions - // ======================================================================= - - override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() - - def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(collect() : _*) - - override def take(num: Int): Array[Row] = limit(num).collect() - - // ======================================================================= - // Base RDD functions that do NOT change schema - // ======================================================================= - - // Transformations (return a new RDD) - - override def coalesce(numPartitions: Int, shuffle: Boolean = false) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.coalesce(numPartitions, shuffle)(ord)) - - override def distinct(): SchemaRDD = applySchema(super.distinct()) - - override def distinct(numPartitions: Int) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.distinct(numPartitions)(ord)) - - def distinct(numPartitions: Int): SchemaRDD = - applySchema(super.distinct(numPartitions)(null)) - - override def filter(f: Row => Boolean): SchemaRDD = - applySchema(super.filter(f)) - - override def intersection(other: RDD[Row]): SchemaRDD = - applySchema(super.intersection(other)) - - override def intersection(other: RDD[Row], partitioner: Partitioner) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.intersection(other, partitioner)(ord)) - - override def intersection(other: RDD[Row], numPartitions: Int): SchemaRDD = - applySchema(super.intersection(other, numPartitions)) - - override def repartition(numPartitions: Int) - 
(implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.repartition(numPartitions)(ord)) - - override def subtract(other: RDD[Row]): SchemaRDD = - applySchema(super.subtract(other)) - - override def subtract(other: RDD[Row], numPartitions: Int): SchemaRDD = - applySchema(super.subtract(other, numPartitions)) - - override def subtract(other: RDD[Row], p: Partitioner) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.subtract(other, p)(ord)) - - /** Overridden cache function will always use the in-memory columnar caching. */ - override def cache(): this.type = { - sqlContext.cacheQuery(this) - this - } - - override def persist(newLevel: StorageLevel): this.type = { - sqlContext.cacheQuery(this, None, newLevel) - this - } - - override def unpersist(blocking: Boolean): this.type = { - sqlContext.tryUncacheQuery(this, blocking) - this - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala deleted file mode 100644 index 3cf9209465b76..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ /dev/null @@ -1,139 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*/ - -package org.apache.spark.sql - -import org.apache.spark.annotation.{DeveloperApi, Experimental} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.LogicalRDD - -/** - * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java) - */ -private[sql] trait SchemaRDDLike { - @transient def sqlContext: SQLContext - @transient val baseLogicalPlan: LogicalPlan - - private[sql] def baseSchemaRDD: SchemaRDD - - /** - * :: DeveloperApi :: - * A lazily computed query execution workflow. All other RDD operations are passed - * through to the RDD that is produced by this workflow. This workflow is produced lazily because - * invoking the whole query optimization pipeline can be expensive. - * - * The query execution is considered a Developer API as phases may be added or removed in future - * releases. This execution is only exposed to provide an interface for inspecting the various - * phases for debugging purposes. Applications should not depend on particular phases existing - * or producing any specific output, even for exactly the same query. - * - * Additionally, the RDD exposed by this execution is not designed for consumption by end users. - * In particular, it does not contain any schema information, and it reuses Row objects - * internally. This object reuse improves performance, but can make programming against the RDD - * more difficult. Instead end users should perform RDD operations on a SchemaRDD directly. - */ - @transient - @DeveloperApi - lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan) - - @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match { - // For various commands (like DDL) and queries with side effects, we force query optimization to - // happen right away to let these side effects take place eagerly. 
- case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile => - LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) - case _ => - baseLogicalPlan - } - - override def toString = - s"""${super.toString} - |== Query Plan == - |${queryExecution.simpleString}""".stripMargin.trim - - /** - * Saves the contents of this `SchemaRDD` as a parquet file, preserving the schema. Files that - * are written out using this method can be read back in as a SchemaRDD using the `parquetFile` - * function. - * - * @group schema - */ - def saveAsParquetFile(path: String): Unit = { - sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd - } - - /** - * Registers this RDD as a temporary table using the given name. The lifetime of this temporary - * table is tied to the [[SQLContext]] that was used to create this SchemaRDD. - * - * @group schema - */ - def registerTempTable(tableName: String): Unit = { - sqlContext.registerRDDAsTable(baseSchemaRDD, tableName) - } - - @deprecated("Use registerTempTable instead of registerAsTable.", "1.1") - def registerAsTable(tableName: String): Unit = registerTempTable(tableName) - - /** - * :: Experimental :: - * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. - * - * @group schema - */ - @Experimental - def insertInto(tableName: String, overwrite: Boolean): Unit = - sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), - Map.empty, logicalPlan, overwrite)).toRdd - - /** - * :: Experimental :: - * Appends the rows from this RDD to the specified table. - * - * @group schema - */ - @Experimental - def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) - - /** - * :: Experimental :: - * Creates a table from the the contents of this SchemaRDD. This will fail if the table already - * exists. 
- * - * Note that this currently only works with SchemaRDDs that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * @group schema - */ - @Experimental - def saveAsTable(tableName: String): Unit = - sqlContext.executePlan(CreateTableAsSelect(None, tableName, logicalPlan, false)).toRdd - - /** Returns the schema as a string in the tree format. - * - * @group schema - */ - def schemaString: String = baseSchemaRDD.schema.treeString - - /** Prints out the schema. - * - * @group schema - */ - def printSchema(): Unit = println(schemaString) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala index 2e9d037f93c03..1beb19437a8da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -21,7 +21,7 @@ import java.util.{List => JList, Map => JMap} import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.Accumulator +import org.apache.spark.{Accumulator, Logging} import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.api.java._ @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.DataType /** * Functions for registering user-defined functions. 
*/ -class UDFRegistration (sqlContext: SQLContext) extends org.apache.spark.Logging { +class UDFRegistration(sqlContext: SQLContext) extends Logging { private val functionRegistry = sqlContext.functionRegistry diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala new file mode 100644 index 0000000000000..eb0eb3f32560c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala @@ -0,0 +1,299 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.reflect.ClassTag + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.storage.StorageLevel + + +/** + * An internal interface defining the RDD-like methods for [[DataFrame]]. + * Please use [[DataFrame]] directly, and do NOT use this. 
+ */ +private[sql] trait RDDApi[T] { + + def cache(): this.type = persist() + + def persist(): this.type + + def persist(newLevel: StorageLevel): this.type + + def unpersist(): this.type = unpersist(blocking = false) + + def unpersist(blocking: Boolean): this.type + + def map[R: ClassTag](f: T => R): RDD[R] + + def flatMap[R: ClassTag](f: T => TraversableOnce[R]): RDD[R] + + def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R] + + def foreach(f: T => Unit): Unit + + def foreachPartition(f: Iterator[T] => Unit): Unit + + def take(n: Int): Array[T] + + def collect(): Array[T] + + def collectAsList(): java.util.List[T] + + def count(): Long + + def first(): T + + def repartition(numPartitions: Int): DataFrame +} + + +/** + * An internal interface defining data frame related methods in [[DataFrame]]. + * Please use [[DataFrame]] directly, and do NOT use this. + */ +private[sql] trait DataFrameSpecificApi { + + def schema: StructType + + def printSchema(): Unit + + def dtypes: Array[(String, String)] + + def columns: Array[String] + + def head(): Row + + def head(n: Int): Array[Row] + + ///////////////////////////////////////////////////////////////////////////// + // Relational operators + ///////////////////////////////////////////////////////////////////////////// + def apply(colName: String): Column + + def apply(projection: Product): DataFrame + + @scala.annotation.varargs + def select(cols: Column*): DataFrame + + @scala.annotation.varargs + def select(col: String, cols: String*): DataFrame + + def apply(condition: Column): DataFrame + + def as(name: String): DataFrame + + def filter(condition: Column): DataFrame + + def where(condition: Column): DataFrame + + @scala.annotation.varargs + def groupBy(cols: Column*): GroupedDataFrame + + @scala.annotation.varargs + def groupBy(col1: String, cols: String*): GroupedDataFrame + + def agg(exprs: Map[String, String]): DataFrame + + def agg(exprs: java.util.Map[String, String]): DataFrame + + 
@scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame + + @scala.annotation.varargs + def sort(sortExpr: Column, sortExprs: Column*): DataFrame + + @scala.annotation.varargs + def sort(sortCol: String, sortCols: String*): DataFrame + + @scala.annotation.varargs + def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame + + @scala.annotation.varargs + def orderBy(sortCol: String, sortCols: String*): DataFrame + + def join(right: DataFrame): DataFrame + + def join(right: DataFrame, joinExprs: Column): DataFrame + + def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame + + def limit(n: Int): DataFrame + + def unionAll(other: DataFrame): DataFrame + + def intersect(other: DataFrame): DataFrame + + def except(other: DataFrame): DataFrame + + def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame + + def sample(withReplacement: Boolean, fraction: Double): DataFrame + + ///////////////////////////////////////////////////////////////////////////// + // Column mutation + ///////////////////////////////////////////////////////////////////////////// + def addColumn(colName: String, col: Column): DataFrame + + ///////////////////////////////////////////////////////////////////////////// + // I/O and interaction with other frameworks + ///////////////////////////////////////////////////////////////////////////// + + def rdd: RDD[Row] + + def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD() + + def toJSON: RDD[String] + + def registerTempTable(tableName: String): Unit + + def saveAsParquetFile(path: String): Unit + + @Experimental + def saveAsTable(tableName: String): Unit + + @Experimental + def insertInto(tableName: String, overwrite: Boolean): Unit + + @Experimental + def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) + + ///////////////////////////////////////////////////////////////////////////// + // Stat functions + 
///////////////////////////////////////////////////////////////////////////// +// def describe(): Unit +// +// def mean(): Unit +// +// def max(): Unit +// +// def min(): Unit +} + + +/** + * An internal interface defining expression APIs for [[DataFrame]]. + * Please use [[DataFrame]] and [[Column]] directly, and do NOT use this. + */ +private[sql] trait ExpressionApi { + + def isComputable: Boolean + + def unary_- : Column + def unary_! : Column + def unary_~ : Column + + def + (other: Column): Column + def + (other: Any): Column + def - (other: Column): Column + def - (other: Any): Column + def * (other: Column): Column + def * (other: Any): Column + def / (other: Column): Column + def / (other: Any): Column + def % (other: Column): Column + def % (other: Any): Column + def & (other: Column): Column + def & (other: Any): Column + def | (other: Column): Column + def | (other: Any): Column + def ^ (other: Column): Column + def ^ (other: Any): Column + + def && (other: Column): Column + def && (other: Boolean): Column + def || (other: Column): Column + def || (other: Boolean): Column + + def < (other: Column): Column + def < (other: Any): Column + def <= (other: Column): Column + def <= (other: Any): Column + def > (other: Column): Column + def > (other: Any): Column + def >= (other: Column): Column + def >= (other: Any): Column + def === (other: Column): Column + def === (other: Any): Column + def equalTo(other: Column): Column + def equalTo(other: Any): Column + def <=> (other: Column): Column + def <=> (other: Any): Column + def !== (other: Column): Column + def !== (other: Any): Column + + @scala.annotation.varargs + def in(list: Column*): Column + + def like(other: String): Column + def rlike(other: String): Column + + def contains(other: Column): Column + def contains(other: Any): Column + def startsWith(other: Column): Column + def startsWith(other: String): Column + def endsWith(other: Column): Column + def endsWith(other: String): Column + + def 
substr(startPos: Column, len: Column): Column + def substr(startPos: Int, len: Int): Column + + def isNull: Column + def isNotNull: Column + + def getItem(ordinal: Int): Column + def getField(fieldName: String): Column + + def cast(to: DataType): Column + def cast(to: String): Column + + def asc: Column + def desc: Column + + def as(alias: String): Column +} + + +/** + * An internal interface defining aggregation APIs for [[DataFrame]]. + * Please use [[DataFrame]] and [[GroupedDataFrame]] directly, and do NOT use this. + */ +private[sql] trait GroupedDataFrameApi { + + def agg(exprs: Map[String, String]): DataFrame + + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame + + def avg(): DataFrame + + def mean(): DataFrame + + def min(): DataFrame + + def max(): DataFrame + + def sum(): DataFrame + + def count(): DataFrame + + // TODO: Add var, std +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 52a31f01a4358..6fba76c52171b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Row, Attribute} import org.apache.spark.sql.catalyst.plans.logical @@ -137,7 +137,9 @@ case class CacheTableCommand( isLazy: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext) = { - plan.foreach(p => new SchemaRDD(sqlContext, p).registerTempTable(tableName)) + plan.foreach { logicalPlan => + sqlContext.registerRDDAsTable(new 
DataFrame(sqlContext, logicalPlan), tableName) + } sqlContext.cacheTable(tableName) if (!isLazy) { @@ -159,7 +161,7 @@ case class CacheTableCommand( case class UncacheTableCommand(tableName: String) extends RunnableCommand { override def run(sqlContext: SQLContext) = { - sqlContext.table(tableName).unpersist() + sqlContext.table(tableName).unpersist(blocking = false) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 4d7e338e8ed13..5cc67cdd13944 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.HashSet import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext._ -import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.types._ @@ -39,10 +39,10 @@ package object debug { /** * :: DeveloperApi :: - * Augments SchemaRDDs with debug methods. + * Augments [[DataFrame]]s with debug methods. */ @DeveloperApi - implicit class DebugQuery(query: SchemaRDD) { + implicit class DebugQuery(query: DataFrame) { def debug(): Unit = { val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[TreeNodeRef]() @@ -166,7 +166,7 @@ package object debug { /** * :: DeveloperApi :: - * Augments SchemaRDDs with debug methods. + * Augments [[DataFrame]]s with debug methods. 
*/ @DeveloperApi private[sql] case class TypeCheck(child: SparkPlan) extends SparkPlan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 6dd39be807037..7c49b5220d607 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -37,5 +37,5 @@ package object sql { * Converts a logical plan into zero or more SparkPlans. */ @DeveloperApi - type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] + protected[sql] type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 9d9150246c8d4..10df8c3310092 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.parquet import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} +import parquet.column.Dictionary import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} import parquet.schema.MessageType @@ -102,12 +103,8 @@ private[sql] object CatalystConverter { } // Strings, Shorts and Bytes do not have a corresponding type in Parquet // so we need to treat them separately - case StringType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addBinary(value: Binary): Unit = - parent.updateString(fieldIndex, value) - } - } + case StringType => + new CatalystPrimitiveStringConverter(parent, fieldIndex) case ShortType => { new CatalystPrimitiveConverter(parent, fieldIndex) { override def addInt(value: Int): Unit = @@ -197,8 +194,8 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def 
updateBinary(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, value.getBytes) - protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = - updateField(fieldIndex, value.toStringUsingUTF8) + protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = + updateField(fieldIndex, value) protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { updateField(fieldIndex, readDecimal(new Decimal(), value, ctype)) @@ -384,8 +381,8 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = current.update(fieldIndex, value.getBytes) - override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = - current.setString(fieldIndex, value.toStringUsingUTF8) + override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = + current.setString(fieldIndex, value) override protected[parquet] def updateDecimal( fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { @@ -426,6 +423,33 @@ private[parquet] class CatalystPrimitiveConverter( parent.updateLong(fieldIndex, value) } +/** + * A `parquet.io.api.PrimitiveConverter` that converts Parquet Binary to Catalyst String. + * Supports dictionaries to reduce Binary to String conversion overhead. + * + * Follows pattern in Parquet of using dictionaries, where supported, for String conversion. + * + * @param parent The parent group converter. + * @param fieldIndex The index inside the record. 
+ */ +private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int) + extends CatalystPrimitiveConverter(parent, fieldIndex) { + + private[this] var dict: Array[String] = null + + override def hasDictionarySupport: Boolean = true + + override def setDictionary(dictionary: Dictionary):Unit = + dict = Array.tabulate(dictionary.getMaxId + 1) {dictionary.decodeToBinary(_).toStringUsingUTF8} + + + override def addValueFromDictionary(dictionaryId: Int): Unit = + parent.updateString(fieldIndex, dict(dictionaryId)) + + override def addBinary(value: Binary): Unit = + parent.updateString(fieldIndex, value.toStringUsingUTF8) +} + private[parquet] object CatalystArrayConverter { val INITIAL_ARRAY_SIZE = 20 } @@ -583,9 +607,9 @@ private[parquet] class CatalystNativeArrayConverter( elements += 1 } - override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = { + override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = { checkGrowBuffer() - buffer(elements) = value.toStringUsingUTF8.asInstanceOf[NativeType] + buffer(elements) = value.asInstanceOf[NativeType] elements += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index f08350878f239..0357dcc4688be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -164,33 +164,57 @@ private[sql] object ParquetFilters { case EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeEq.lift(dataType).map(_(name, value)) + case EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeEq.lift(dataType).map(_(name, value)) case EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeEq.lift(dataType).map(_(name, value)) - + case EqualTo(NonNullLiteral(value, _), 
Cast(NamedExpression(name, _), dataType)) => + makeEq.lift(dataType).map(_(name, value)) + case Not(EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType))) => makeNotEq.lift(dataType).map(_(name, value)) + case Not(EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _))) => + makeNotEq.lift(dataType).map(_(name, value)) case Not(EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _))) => makeNotEq.lift(dataType).map(_(name, value)) + case Not(EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType))) => + makeNotEq.lift(dataType).map(_(name, value)) case LessThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeLt.lift(dataType).map(_(name, value)) + case LessThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeLt.lift(dataType).map(_(name, value)) case LessThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeGt.lift(dataType).map(_(name, value)) + case LessThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeGt.lift(dataType).map(_(name, value)) case LessThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeLtEq.lift(dataType).map(_(name, value)) + case LessThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeLtEq.lift(dataType).map(_(name, value)) case LessThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeGtEq.lift(dataType).map(_(name, value)) + case LessThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeGtEq.lift(dataType).map(_(name, value)) case GreaterThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeGt.lift(dataType).map(_(name, value)) + case GreaterThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeGt.lift(dataType).map(_(name, value)) case GreaterThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => 
makeLt.lift(dataType).map(_(name, value)) + case GreaterThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeLt.lift(dataType).map(_(name, value)) case GreaterThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeGtEq.lift(dataType).map(_(name, value)) + case GreaterThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeGtEq.lift(dataType).map(_(name, value)) case GreaterThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeLtEq.lift(dataType).map(_(name, value)) + case GreaterThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeLtEq.lift(dataType).map(_(name, value)) case And(lhs, rhs) => (createFilter(lhs) ++ createFilter(rhs)).reduceOption(FilterApi.and) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index cde5160149e9c..a54485e719dad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -26,7 +26,7 @@ import parquet.hadoop.ParquetOutputFormat import parquet.hadoop.metadata.CompressionCodecName import parquet.schema.MessageType -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} @@ -34,8 +34,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Stati /** * Relation that consists of data stored in a Parquet columnar format. * - * Users should interact with parquet files though a SchemaRDD, created by a [[SQLContext]] instead - * of using this class directly. 
+ * Users should interact with parquet files through a [[DataFrame]], created by a [[SQLContext]] + * instead of using this class directly. * * {{{ * val parquetRDD = sqlContext.parquetFile("path/to/parquet.file") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 02ce1b3e6d811..9d6c529574da0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import scala.util.Try -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.util import org.apache.spark.util.Utils @@ -95,12 +95,12 @@ trait ParquetTest { } /** - * Writes `data` to a Parquet file and reads it back as a SchemaRDD, which is then passed to `f`. - * The Parquet file will be deleted after `f` returns. + * Writes `data` to a Parquet file and reads it back as a [[DataFrame]], + * which is then passed to `f`. The Parquet file will be deleted after `f` returns. */ protected def withParquetRDD[T <: Product: ClassTag: TypeTag] (data: Seq[T]) - (f: SchemaRDD => Unit): Unit = { + (f: DataFrame => Unit): Unit = { withParquetFile(data)(path => f(parquetFile(path))) } @@ -112,15 +112,15 @@ trait ParquetTest { } /** - * Writes `data` to a Parquet file, reads it back as a SchemaRDD and registers it as a temporary - * table named `tableName`, then call `f`. The temporary table together with the Parquet file will - * be dropped/deleted after `f` returns. + * Writes `data` to a Parquet file, reads it back as a [[DataFrame]] and registers it as a + * temporary table named `tableName`, then calls `f`. The temporary table together with the + * Parquet file will be dropped/deleted after `f` returns.
*/ protected def withParquetTable[T <: Product: ClassTag: TypeTag] (data: Seq[T], tableName: String) (f: => Unit): Unit = { withParquetRDD(data) { rdd => - rdd.registerTempTable(tableName) + sqlContext.registerRDDAsTable(rdd, tableName) withTempTable(tableName)(f) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 37853d4d03019..d13f2ce2a5e1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -18,19 +18,18 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row -import org.apache.spark.sql._ +import org.apache.spark.sql.{Row, Strategy} import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution /** * A Strategy for planning scans over data sources defined using the sources API. */ private[sql] object DataSourceStrategy extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: CatalystScan)) => pruneFilterProjectRaw( l, @@ -112,23 +111,26 @@ private[sql] object DataSourceStrategy extends Strategy { } } + /** Turn Catalyst [[Expression]]s into data source [[Filter]]s. 
*/ protected[sql] def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { - case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v) - case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v) + case expressions.EqualTo(a: Attribute, expressions.Literal(v, _)) => EqualTo(a.name, v) + case expressions.EqualTo(expressions.Literal(v, _), a: Attribute) => EqualTo(a.name, v) - case expressions.GreaterThan(a: Attribute, Literal(v, _)) => GreaterThan(a.name, v) - case expressions.GreaterThan(Literal(v, _), a: Attribute) => LessThan(a.name, v) + case expressions.GreaterThan(a: Attribute, expressions.Literal(v, _)) => GreaterThan(a.name, v) + case expressions.GreaterThan(expressions.Literal(v, _), a: Attribute) => LessThan(a.name, v) - case expressions.LessThan(a: Attribute, Literal(v, _)) => LessThan(a.name, v) - case expressions.LessThan(Literal(v, _), a: Attribute) => GreaterThan(a.name, v) + case expressions.LessThan(a: Attribute, expressions.Literal(v, _)) => LessThan(a.name, v) + case expressions.LessThan(expressions.Literal(v, _), a: Attribute) => GreaterThan(a.name, v) - case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => + case expressions.GreaterThanOrEqual(a: Attribute, expressions.Literal(v, _)) => GreaterThanOrEqual(a.name, v) - case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => + case expressions.GreaterThanOrEqual(expressions.Literal(v, _), a: Attribute) => LessThanOrEqual(a.name, v) - case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v) - case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v) + case expressions.LessThanOrEqual(a: Attribute, expressions.Literal(v, _)) => + LessThanOrEqual(a.name, v) + case expressions.LessThanOrEqual(expressions.Literal(v, _), a: Attribute) => + GreaterThanOrEqual(a.name, v) case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 171b816a26332..b4af91a768efb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.sources import scala.language.implicitConversions import org.apache.spark.Logging -import org.apache.spark.sql.{SchemaRDD, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.execution.RunnableCommand @@ -225,7 +225,8 @@ private [sql] case class CreateTempTableUsing( def run(sqlContext: SQLContext) = { val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options) - new SchemaRDD(sqlContext, LogicalRelation(resolved.relation)).registerTempTable(tableName) + sqlContext.registerRDDAsTable( + new DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) Seq.empty } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index f9c082216085d..906455dd40c0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.test import scala.language.implicitConversions import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** A SQLContext that can be used for local testing. */ @@ -37,11 +37,11 @@ object TestSQLContext } /** - * Turn a logical plan into a SchemaRDD. 
This should be removed once we have an easier way to - * construct SchemaRDD directly out of local data without relying on implicits. + * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to + * construct [[DataFrame]] directly out of local data without relying on implicits. */ - protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): SchemaRDD = { - new SchemaRDD(this, plan) + protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + new DataFrame(this, plan) } } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java index 9ff40471a00af..e5588938ea162 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java @@ -61,7 +61,7 @@ public Integer call(String str) throws Exception { } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test')").first(); + Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); assert(result.getInt(0) == 4); } @@ -81,7 +81,7 @@ public Integer call(String str1, String str2) throws Exception { } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").first(); + Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); assert(result.getInt(0) == 9); } } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java index 9e96738ac095a..badd00d34b9b1 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -98,8 +98,8 @@ public Row call(Person person) throws Exception { fields.add(DataTypes.createStructField("age", 
DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - SchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD.rdd(), schema); - schemaRDD.registerTempTable("people"); + DataFrame df = javaSqlCtx.applySchema(rowRDD.rdd(), schema); + df.registerTempTable("people"); Row[] actual = javaSqlCtx.sql("SELECT * FROM people").collect(); List expected = new ArrayList(2); @@ -147,17 +147,17 @@ public void applySchemaToJSON() { null, "this is another simple string.")); - SchemaRDD schemaRDD1 = javaSqlCtx.jsonRDD(jsonRDD.rdd()); - StructType actualSchema1 = schemaRDD1.schema(); + DataFrame df1 = javaSqlCtx.jsonRDD(jsonRDD.rdd()); + StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); - schemaRDD1.registerTempTable("jsonTable1"); + df1.registerTempTable("jsonTable1"); List actual1 = javaSqlCtx.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - SchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema); - StructType actualSchema2 = schemaRDD2.schema(); + DataFrame df2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema); + StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); - schemaRDD2.registerTempTable("jsonTable2"); + df2.registerTempTable("jsonTable2"); List actual2 = javaSqlCtx.sql("select * from jsonTable2").collectAsList(); Assert.assertEquals(expectedResult, actual2); } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java new file mode 100644 index 0000000000000..639436368c4a3 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import com.google.common.collect.ImmutableMap; + +import org.apache.spark.sql.Column; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.types.DataTypes; + +import static org.apache.spark.sql.Dsl.*; + +/** + * This test doesn't actually run anything. It is here to check the API compatibility for Java. + */ +public class JavaDsl { + + public static void testDataFrame(final DataFrame df) { + DataFrame df1 = df.select("colA"); + df1 = df.select("colA", "colB"); + + df1 = df.select(col("colA"), col("colB"), lit("literal value").$plus(1)); + + df1 = df.filter(col("colA")); + + java.util.Map aggExprs = ImmutableMap.builder() + .put("colA", "sum") + .put("colB", "avg") + .build(); + + df1 = df.agg(aggExprs); + + df1 = df.groupBy("groupCol").agg(aggExprs); + + df1 = df.join(df1, col("key1").$eq$eq$eq(col("key2")), "outer"); + + df.orderBy("colA"); + df.orderBy("colA", "colB", "colC"); + df.orderBy(col("colA").desc()); + df.orderBy(col("colA").desc(), col("colB").asc()); + + df.sort("colA"); + df.sort("colA", "colB", "colC"); + df.sort(col("colA").desc()); + df.sort(col("colA").desc(), col("colB").asc()); + + df.as("b"); + + df.limit(5); + + df.unionAll(df1); + df.intersect(df1); + df.except(df1); + + df.sample(true, 0.1, 234); + + df.head(); + df.head(5); + df.first(); + df.count(); + } + + public static void testColumn(final 
Column c) { + c.asc(); + c.desc(); + + c.endsWith("abcd"); + c.startsWith("afgasdf"); + + c.like("asdf%"); + c.rlike("wef%asdf"); + + c.as("newcol"); + + c.cast("int"); + c.cast(DataTypes.IntegerType); + } + + public static void testDsl() { + // Creating a column. + Column c = col("abcd"); + Column c1 = column("abcd"); + + // Literals + Column l1 = lit(1); + Column l2 = lit(1.0); + Column l3 = lit("abcd"); + + // Functions + Column a = upper(c); + a = lower(c); + a = sqrt(c); + a = abs(c); + + // Aggregates + a = min(c); + a = max(c); + a = sum(c); + a = sumDistinct(c); + a = countDistinct(c, a); + a = avg(c); + a = first(c); + a = last(c); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index cfc037caff2a9..c9221f8f934ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} @@ -50,17 +51,17 @@ class CachedTableSuite extends QueryTest { } test("unpersist an uncached table will not raise exception") { - assert(None == lookupCachedData(testData)) + assert(None == cacheManager.lookupCachedData(testData)) testData.unpersist(true) - assert(None == lookupCachedData(testData)) + assert(None == cacheManager.lookupCachedData(testData)) testData.unpersist(false) - assert(None == lookupCachedData(testData)) + assert(None == cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != lookupCachedData(testData)) + assert(None != cacheManager.lookupCachedData(testData)) testData.unpersist(true) - assert(None == lookupCachedData(testData)) + assert(None == 
cacheManager.lookupCachedData(testData)) testData.unpersist(false) - assert(None == lookupCachedData(testData)) + assert(None == cacheManager.lookupCachedData(testData)) } test("cache table as select") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala new file mode 100644 index 0000000000000..2d464c2b53d79 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType} + + +class ColumnExpressionSuite extends QueryTest { + import org.apache.spark.sql.TestData._ + + // TODO: Add test cases for bitwise operations. + + test("star") { + checkAnswer(testData.select($"*"), testData.collect().toSeq) + } + + test("star qualified by data frame object") { + // This is not yet supported. NOTE(review): but it runs via test(), not ignore() - confirm.
+ val df = testData.toDataFrame + val goldAnswer = df.collect().toSeq + checkAnswer(df.select(df("*")), goldAnswer) + + val df1 = df.select(df("*"), lit("abcd").as("litCol")) + checkAnswer(df1.select(df("*")), goldAnswer) + } + + test("star qualified by table name") { + checkAnswer(testData.as("testData").select($"testData.*"), testData.collect().toSeq) + } + + test("+") { + checkAnswer( + testData2.select($"a" + 1), + testData2.collect().toSeq.map(r => Row(r.getInt(0) + 1))) + + checkAnswer( + testData2.select($"a" + $"b" + 2), + testData2.collect().toSeq.map(r => Row(r.getInt(0) + r.getInt(1) + 2))) + } + + test("-") { + checkAnswer( + testData2.select($"a" - 1), + testData2.collect().toSeq.map(r => Row(r.getInt(0) - 1))) + + checkAnswer( + testData2.select($"a" - $"b" - 2), + testData2.collect().toSeq.map(r => Row(r.getInt(0) - r.getInt(1) - 2))) + } + + test("*") { + checkAnswer( + testData2.select($"a" * 10), + testData2.collect().toSeq.map(r => Row(r.getInt(0) * 10))) + + checkAnswer( + testData2.select($"a" * $"b"), + testData2.collect().toSeq.map(r => Row(r.getInt(0) * r.getInt(1)))) + } + + test("/") { + checkAnswer( + testData2.select($"a" / 2), + testData2.collect().toSeq.map(r => Row(r.getInt(0).toDouble / 2))) + + checkAnswer( + testData2.select($"a" / $"b"), + testData2.collect().toSeq.map(r => Row(r.getInt(0).toDouble / r.getInt(1)))) + } + + + test("%") { + checkAnswer( + testData2.select($"a" % 2), + testData2.collect().toSeq.map(r => Row(r.getInt(0) % 2))) + + checkAnswer( + testData2.select($"a" % $"b"), + testData2.collect().toSeq.map(r => Row(r.getInt(0) % r.getInt(1)))) + } + + test("unary -") { + checkAnswer( + testData2.select(-$"a"), + testData2.collect().toSeq.map(r => Row(-r.getInt(0)))) + } + + test("unary !") { + checkAnswer( + complexData.select(!$"b"), + complexData.collect().toSeq.map(r => Row(!r.getBoolean(3)))) + } + + test("isNull") { + checkAnswer( + nullStrings.toDataFrame.where($"s".isNull), + 
+ nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) + } + + test("isNotNull") { + checkAnswer( + nullStrings.toDataFrame.where($"s".isNotNull), + nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) + } + + test("===") { + checkAnswer( + testData2.filter($"a" === 1), + testData2.collect().toSeq.filter(r => r.getInt(0) == 1)) + + checkAnswer( + testData2.filter($"a" === $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1))) + } + + test("!==") { + checkAnswer( + testData2.filter($"a" !== 1), + testData2.collect().toSeq.filter(r => r.getInt(0) != 1)) + + checkAnswer( + testData2.filter($"a" !== $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) != r.getInt(1))) + } + + test("<=>") { + val nullData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize( + Row(1, 1) :: + Row(1, 2) :: + Row(1, null) :: + Row(null, null) :: Nil), + StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType)))) + + checkAnswer( + nullData.filter($"b" <=> 1), + Row(1, 1) :: Nil) + + checkAnswer( + nullData.filter($"b" <=> null), + Row(1, null) :: Row(null, null) :: Nil) + + checkAnswer( + nullData.filter($"a" <=> $"b"), + Row(1, 1) :: Row(null, null) :: Nil) + } + + test(">") { + checkAnswer( + testData2.filter($"a" > 1), + testData2.collect().toSeq.filter(r => r.getInt(0) > 1)) + + checkAnswer( + testData2.filter($"a" > $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) > r.getInt(1))) + } + + test(">=") { + checkAnswer( + testData2.filter($"a" >= 1), + testData2.collect().toSeq.filter(r => r.getInt(0) >= 1)) + + checkAnswer( + testData2.filter($"a" >= $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) >= r.getInt(1))) + } + + test("<") { + checkAnswer( + testData2.filter($"a" < 2), + testData2.collect().toSeq.filter(r => r.getInt(0) < 2)) + + checkAnswer( + testData2.filter($"a" < $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) < r.getInt(1))) + } + + test("<=") { + checkAnswer(
+ testData2.filter($"a" <= 2), + testData2.collect().toSeq.filter(r => r.getInt(0) <= 2)) + + checkAnswer( + testData2.filter($"a" <= $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1))) + } + + val booleanData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize( + Row(false, false) :: + Row(false, true) :: + Row(true, false) :: + Row(true, true) :: Nil), + StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) + + test("&&") { + checkAnswer( + booleanData.filter($"a" && true), + Row(true, false) :: Row(true, true) :: Nil) + + checkAnswer( + booleanData.filter($"a" && false), + Nil) + + checkAnswer( + booleanData.filter($"a" && $"b"), + Row(true, true) :: Nil) + } + + test("||") { + checkAnswer( + booleanData.filter($"a" || true), + booleanData.collect()) + + checkAnswer( + booleanData.filter($"a" || false), + Row(true, false) :: Row(true, true) :: Nil) + + checkAnswer( + booleanData.filter($"a" || $"b"), + Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil) + } + + test("sqrt") { + checkAnswer( + testData.select(sqrt('key)).orderBy('key.asc), + (1 to 100).map(n => Row(math.sqrt(n))) + ) + + checkAnswer( + testData.select(sqrt('value), 'key).orderBy('key.asc, 'value.asc), + (1 to 100).map(n => Row(math.sqrt(n), n)) + ) + + checkAnswer( + testData.select(sqrt(lit(null))), + (1 to 100).map(_ => Row(null)) + ) + } + + test("abs") { + checkAnswer( + testData.select(abs('key)).orderBy('key.asc), + (1 to 100).map(n => Row(n)) + ) + + checkAnswer( + negativeData.select(abs('key)).orderBy('key.desc), + (1 to 100).map(n => Row(n)) + ) + + checkAnswer( + testData.select(abs(lit(null))), + (1 to 100).map(_ => Row(null)) + ) + } + + test("upper") { + checkAnswer( + lowerCaseData.select(upper('l)), + ('a' to 'd').map(c => Row(c.toString.toUpperCase)) + ) + + checkAnswer( + testData.select(upper('value), 'key), + (1 to 100).map(n => Row(n.toString, n)) + ) + + checkAnswer( + 
testData.select(upper(lit(null))), + (1 to 100).map(n => Row(null)) + ) + } + + test("lower") { + checkAnswer( + upperCaseData.select(lower('L)), + ('A' to 'F').map(c => Row(c.toString.toLowerCase)) + ) + + checkAnswer( + testData.select(lower('value), 'key), + (1 to 100).map(n => Row(n.toString, n)) + ) + + checkAnswer( + testData.select(lower(lit(null))), + (1 to 100).map(n => Row(null)) + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala similarity index 52% rename from sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index afbfe214f1ce4..df343adc793bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.types._ /* Implicits */ -import org.apache.spark.sql.catalyst.dsl._ import org.apache.spark.sql.test.TestSQLContext._ import scala.language.postfixOps -class DslQuerySuite extends QueryTest { +class DataFrameSuite extends QueryTest { import org.apache.spark.sql.TestData._ test("table scan") { @@ -44,46 +42,46 @@ class DslQuerySuite extends QueryTest { test("agg") { checkAnswer( - testData2.groupBy('a)('a, sum('b)), + testData2.groupBy("a").agg($"a", sum($"b")), Seq(Row(1,3), Row(2,3), Row(3,3)) ) checkAnswer( - testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)), + testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)), Row(9) ) checkAnswer( - testData2.aggregate(sum('b)), + testData2.agg(sum('b)), Row(9) ) } test("convert $\"attribute name\" into unresolved attribute") { checkAnswer( - testData.where($"key" === 1).select($"value"), + 
testData.where($"key" === lit(1)).select($"value"), Row("1")) } test("convert Scala Symbol 'attrname into unresolved attribute") { checkAnswer( - testData.where('key === 1).select('value), + testData.where('key === lit(1)).select('value), Row("1")) } test("select *") { checkAnswer( - testData.select(Star(None)), + testData.select($"*"), testData.collect().toSeq) } test("simple select") { checkAnswer( - testData.where('key === 1).select('value), + testData.where('key === lit(1)).select('value), Row("1")) } test("select with functions") { checkAnswer( - testData.select(sum('value), avg('value), count(1)), + testData.select(sum('value), avg('value), count(lit(1))), Row(5050.0, 50.5, 100)) checkAnswer( @@ -120,46 +118,19 @@ class DslQuerySuite extends QueryTest { checkAnswer( arrayData.orderBy('data.getItem(0).asc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) + arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) checkAnswer( arrayData.orderBy('data.getItem(0).desc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) + arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) checkAnswer( arrayData.orderBy('data.getItem(1).asc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) + arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) checkAnswer( arrayData.orderBy('data.getItem(1).desc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) - } - - test("partition wide sorting") { - // 2 partitions totally, and - // Partition #1 with values: - // (1, 1) - // (1, 2) - // (2, 1) - // Partition #2 with values: - // (2, 2) - // (3, 1) - // (3, 2) - checkAnswer( - testData2.sortBy('a.asc, 'b.asc), - Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) - - checkAnswer( - testData2.sortBy('a.asc, 'b.desc), - Seq(Row(1,2), Row(1,1), Row(2,1), Row(2,2), Row(3,2), Row(3,1))) - - checkAnswer( - 
testData2.sortBy('a.desc, 'b.desc), - Seq(Row(2,1), Row(1,2), Row(1,1), Row(3,2), Row(3,1), Row(2,2))) - - checkAnswer( - testData2.sortBy('a.desc, 'b.asc), - Seq(Row(2,1), Row(1,1), Row(1,2), Row(3,1), Row(3,2), Row(2,2))) + arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) } test("limit") { @@ -176,71 +147,51 @@ class DslQuerySuite extends QueryTest { mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) } - test("SPARK-3395 limit distinct") { - val filtered = TestData.testData2 - .distinct() - .orderBy(SortOrder('a, Ascending), SortOrder('b, Ascending)) - .limit(1) - .registerTempTable("onerow") - checkAnswer( - sql("select * from onerow inner join testData2 on onerow.a = testData2.a"), - Row(1, 1, 1, 1) :: - Row(1, 1, 1, 2) :: Nil) - } - - test("SPARK-3858 generator qualifiers are discarded") { - checkAnswer( - arrayData.as('ad) - .generate(Explode("data" :: Nil, 'data), alias = Some("ex")) - .select("ex.data".attr), - Seq(1, 2, 3, 2, 3, 4).map(Row(_))) - } - test("average") { checkAnswer( - testData2.aggregate(avg('a)), + testData2.agg(avg('a)), Row(2.0)) checkAnswer( - testData2.aggregate(avg('a), sumDistinct('a)), // non-partial + testData2.agg(avg('a), sumDistinct('a)), // non-partial Row(2.0, 6.0) :: Nil) checkAnswer( - decimalData.aggregate(avg('a)), + decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) checkAnswer( - decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial + decimalData.agg(avg('a), sumDistinct('a)), // non-partial Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) checkAnswer( - decimalData.aggregate(avg('a cast DecimalType(10, 2))), + decimalData.agg(avg('a cast DecimalType(10, 2))), Row(new java.math.BigDecimal(2.0))) checkAnswer( - decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial + decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial Row(new 
java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) } test("null average") { checkAnswer( - testData3.aggregate(avg('b)), + testData3.agg(avg('b)), Row(2.0)) checkAnswer( - testData3.aggregate(avg('b), countDistinct('b)), + testData3.agg(avg('b), countDistinct('b)), Row(2.0, 1)) checkAnswer( - testData3.aggregate(avg('b), sumDistinct('b)), // non-partial + testData3.agg(avg('b), sumDistinct('b)), // non-partial Row(2.0, 2.0)) } test("zero average") { checkAnswer( - emptyTableData.aggregate(avg('a)), + emptyTableData.agg(avg('a)), Row(null)) checkAnswer( - emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial + emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial Row(null, null)) } @@ -248,28 +199,28 @@ class DslQuerySuite extends QueryTest { assert(testData2.count() === testData2.map(_ => 1).count()) checkAnswer( - testData2.aggregate(count('a), sumDistinct('a)), // non-partial + testData2.agg(count('a), sumDistinct('a)), // non-partial Row(6, 6.0)) } test("null count") { checkAnswer( - testData3.groupBy('a)('a, count('b)), + testData3.groupBy('a).agg('a, count('b)), Seq(Row(1,0), Row(2, 1)) ) checkAnswer( - testData3.groupBy('a)('a, count('a + 'b)), + testData3.groupBy('a).agg('a, count('a + 'b)), Seq(Row(1,0), Row(2, 1)) ) checkAnswer( - testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)), + testData3.agg(count('a), count('b), count(lit(1)), countDistinct('a), countDistinct('b)), Row(2, 1, 2, 2, 1) ) checkAnswer( - testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial + testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial Row(1, 1, 2) ) } @@ -278,19 +229,19 @@ class DslQuerySuite extends QueryTest { assert(emptyTableData.count() === 0) checkAnswer( - emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial + emptyTableData.agg(count('a), sumDistinct('a)), // non-partial Row(0, null)) } test("zero sum") { checkAnswer( - 
emptyTableData.aggregate(sum('a)), + emptyTableData.agg(sum('a)), Row(null)) } test("zero sum distinct") { checkAnswer( - emptyTableData.aggregate(sumDistinct('a)), + emptyTableData.agg(sumDistinct('a)), Row(null)) } @@ -320,76 +271,14 @@ class DslQuerySuite extends QueryTest { checkAnswer( // SELECT *, foo(key, value) FROM testData - testData.select(Star(None), foo.call('key, 'value)).limit(3), + testData.select($"*", callUDF(foo, 'key, 'value)).limit(3), Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil ) } - test("sqrt") { - checkAnswer( - testData.select(sqrt('key)).orderBy('key asc), - (1 to 100).map(n => Row(math.sqrt(n))) - ) - - checkAnswer( - testData.select(sqrt('value), 'key).orderBy('key asc, 'value asc), - (1 to 100).map(n => Row(math.sqrt(n), n)) - ) - - checkAnswer( - testData.select(sqrt(Literal(null))), - (1 to 100).map(_ => Row(null)) - ) - } - - test("abs") { - checkAnswer( - testData.select(abs('key)).orderBy('key asc), - (1 to 100).map(n => Row(n)) - ) - - checkAnswer( - negativeData.select(abs('key)).orderBy('key desc), - (1 to 100).map(n => Row(n)) - ) - - checkAnswer( - testData.select(abs(Literal(null))), - (1 to 100).map(_ => Row(null)) - ) - } - - test("upper") { - checkAnswer( - lowerCaseData.select(upper('l)), - ('a' to 'd').map(c => Row(c.toString.toUpperCase())) - ) - - checkAnswer( - testData.select(upper('value), 'key), - (1 to 100).map(n => Row(n.toString, n)) - ) - - checkAnswer( - testData.select(upper(Literal(null))), - (1 to 100).map(n => Row(null)) - ) + test("apply on query results (SPARK-5462)") { + val df = testData.sqlContext.sql("select key from testData") + checkAnswer(df("key"), testData.select('key).collect().toSeq) } - test("lower") { - checkAnswer( - upperCaseData.select(lower('L)), - ('A' to 'F').map(c => Row(c.toString.toLowerCase())) - ) - - checkAnswer( - testData.select(lower('value), 'key), - (1 to 100).map(n => Row(n.toString, n)) - ) - - checkAnswer( - testData.select(lower(Literal(null))), - 
(1 to 100).map(n => Row(null)) - ) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index cd36da7751e83..f0c939dbb195f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -20,19 +20,20 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext._ + class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. TestData test("equi-join is hash-join") { - val x = testData2.as('x) - val y = testData2.as('y) - val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed + val x = testData2.as("x") + val y = testData2.as("y") + val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.analyzed val planned = planner.HashJoin(join) assert(planned.size === 1) } @@ -58,7 +59,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("join operator selection") { - clearCache() + cacheManager.clearCache() Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), @@ -92,7 +93,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("broadcasted hash join operator selection") { - clearCache() + cacheManager.clearCache() sql("CACHE TABLE testData") Seq( @@ -105,17 +106,16 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("multiple-key equi-join is hash-join") { - val x = testData2.as('x) - val y = testData2.as('y) - val join = x.join(y, Inner, - Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed + 
val x = testData2.as("x") + val y = testData2.as("y") + val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.analyzed val planned = planner.HashJoin(join) assert(planned.size === 1) } test("inner join where, one match per row") { checkAnswer( - upperCaseData.join(lowerCaseData, Inner).where('n === 'N), + upperCaseData.join(lowerCaseData).where('n === 'N), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -126,7 +126,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("inner join ON, one match per row") { checkAnswer( - upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)), + upperCaseData.join(lowerCaseData, $"n" === $"N"), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -136,10 +136,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("inner join, where, multiple matches") { - val x = testData2.where('a === 1).as('x) - val y = testData2.where('a === 1).as('y) + val x = testData2.where($"a" === 1).as("x") + val y = testData2.where($"a" === 1).as("y") checkAnswer( - x.join(y).where("x.a".attr === "y.a".attr), + x.join(y).where($"x.a" === $"y.a"), Row(1,1,1,1) :: Row(1,1,1,2) :: Row(1,2,1,1) :: @@ -148,22 +148,21 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("inner join, no matches") { - val x = testData2.where('a === 1).as('x) - val y = testData2.where('a === 2).as('y) + val x = testData2.where($"a" === 1).as("x") + val y = testData2.where($"a" === 2).as("y") checkAnswer( - x.join(y).where("x.a".attr === "y.a".attr), + x.join(y).where($"x.a" === $"y.a"), Nil) } test("big inner join, 4 matches per row") { val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) - val bigDataX = bigData.as('x) - val bigDataY = bigData.as('y) + val bigDataX = bigData.as("x") + val bigDataY = bigData.as("y") checkAnswer( - bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr), - testData.flatMap( - row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) + 
bigDataX.join(bigDataY).where($"x.key" === $"y.key"), + testData.rdd.flatMap(row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } test("cartisian product join") { @@ -177,7 +176,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("left outer join") { checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)), + upperCaseData.join(lowerCaseData, $"n" === $"N", "left"), Row(1, "A", 1, "a") :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -186,7 +185,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, "F", null, null) :: Nil) checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left"), Row(1, "A", null, null) :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -195,7 +194,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, "F", null, null) :: Nil) checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left"), Row(1, "A", null, null) :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -204,7 +203,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, "F", null, null) :: Nil) checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left"), Row(1, "A", 1, "a") :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -240,7 +239,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("right outer join") { checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)), + lowerCaseData.join(upperCaseData, $"n" === $"N", "right"), Row(1, "a", 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -248,7 +247,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) checkAnswer( - 
lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right"), Row(null, null, 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -256,7 +255,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right"), Row(null, null, 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -264,7 +263,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right"), Row(1, "a", 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -306,7 +305,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val right = UnresolvedRelation(Seq("right"), None) checkAnswer( - left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), + left.join(right, $"left.N" === $"right.N", "full"), Row(1, "A", null, null) :: Row(2, "B", null, null) :: Row(3, "C", 3, "C") :: @@ -315,7 +314,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 6, "F") :: Nil) checkAnswer( - left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))), + left.join(right, ($"left.N" === $"right.N") && ($"left.N" !== 3), "full"), Row(1, "A", null, null) :: Row(2, "B", null, null) :: Row(3, "C", null, null) :: @@ -325,7 +324,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 6, "F") :: Nil) checkAnswer( - left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))), + left.join(right, ($"left.N" === $"right.N") && ($"right.N" !== 3), "full"), 
Row(1, "A", null, null) :: Row(2, "B", null, null) :: Row(3, "C", null, null) :: @@ -385,7 +384,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("broadcasted left semi join operator selection") { - clearCache() + cacheManager.clearCache() sql("CACHE TABLE testData") val tmp = conf.autoBroadcastJoinThreshold diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 42a21c148df53..a7f2faa3ecf75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -26,12 +26,12 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer - * @param rdd the [[SchemaRDD]] to be executed + * @param rdd the [[DataFrame]] to be executed * @param exists true for make sure the keywords are listed in the output, otherwise * to make sure none of the keyword are not listed in the output * @param keywords keyword in string array */ - def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) { + def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) { val outputs = rdd.collect().map(_.mkString).mkString for (key <- keywords) { if (exists) { @@ -44,10 +44,10 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. - * @param rdd the [[SchemaRDD]] to be executed + * @param rdd the [[DataFrame]] to be executed * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. 
*/ - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = { + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. @@ -91,7 +91,7 @@ class QueryTest extends PlanTest { } } - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = { + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { checkAnswer(rdd, Seq(expectedAnswer)) } @@ -101,8 +101,10 @@ class QueryTest extends PlanTest { } } - /** Asserts that a given SchemaRDD will be executed using the given number of cached results. */ - def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + /** + * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. + */ + def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 03b44ca1d6695..d684278f11bcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,6 +21,7 @@ import java.util.TimeZone import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ @@ -29,6 +30,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext._ + class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // 
Make sure the tables are loaded. TestData @@ -184,6 +186,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Seq(Row(1,3), Row(2,3), Row(3,3))) } + test("literal in agg grouping expressions") { + checkAnswer( + sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), + Seq(Row(1,2), Row(2,2), Row(3,2))) + checkAnswer( + sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), + Seq(Row(1,2), Row(2,2), Row(3,2))) + } + test("aggregates with nulls") { checkAnswer( sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), @@ -381,8 +392,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("big inner join, 4 matches per row") { - - checkAnswer( sql( """ @@ -396,7 +405,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | SELECT * FROM testData UNION ALL | SELECT * FROM testData) y |WHERE x.key = y.key""".stripMargin), - testData.flatMap( + testData.rdd.flatMap( row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } @@ -651,8 +660,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val schemaRDD1 = applySchema(rowRDD1, schema1) - schemaRDD1.registerTempTable("applySchema1") + val df1 = applySchema(rowRDD1, schema1) + df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), Row(1, "A1", true, null) :: @@ -681,8 +690,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val schemaRDD2 = applySchema(rowRDD2, schema2) - schemaRDD2.registerTempTable("applySchema2") + val df2 = applySchema(rowRDD2, schema2) + df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), Row(Row(1, true), Map("A1" -> null)) :: @@ -706,8 +715,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val 
schemaRDD3 = applySchema(rowRDD3, schema2) - schemaRDD3.registerTempTable("applySchema3") + val df3 = applySchema(rowRDD3, schema2) + df3.registerTempTable("applySchema3") checkAnswer( sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), @@ -742,7 +751,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("metadata is propagated correctly") { - val person = sql("SELECT * FROM person") + val person: DataFrame = sql("SELECT * FROM person") val schema = person.schema val docKey = "doc" val docValue = "first name" @@ -751,14 +760,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = applySchema(person, schemaWithMeta) - def validateMetadata(rdd: SchemaRDD): Unit = { + val personWithMeta = applySchema(person.rdd, schemaWithMeta) + def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } personWithMeta.registerTempTable("personWithMeta") - validateMetadata(personWithMeta.select('name)) - validateMetadata(personWithMeta.select("name".attr)) - validateMetadata(personWithMeta.select('id, 'name)) + validateMetadata(personWithMeta.select($"name")) + validateMetadata(personWithMeta.select($"name")) + validateMetadata(personWithMeta.select($"id", $"name")) validateMetadata(sql("SELECT * FROM personWithMeta")) validateMetadata(sql("SELECT id, name FROM personWithMeta")) validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 808ed5288cfb8..dd781169ca57f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.sql.Timestamp import 
org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.test._ /* Implicits */ @@ -29,11 +30,11 @@ case class TestData(key: Int, value: String) object TestData { val testData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD + (1 to 100).map(i => TestData(i, i.toString))).toDataFrame testData.registerTempTable("testData") val negativeData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(-i, (-i).toString))).toSchemaRDD + (1 to 100).map(i => TestData(-i, (-i).toString))).toDataFrame negativeData.registerTempTable("negativeData") case class LargeAndSmallInts(a: Int, b: Int) @@ -44,7 +45,7 @@ object TestData { LargeAndSmallInts(2147483645, 1) :: LargeAndSmallInts(2, 2) :: LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil).toSchemaRDD + LargeAndSmallInts(3, 2) :: Nil).toDataFrame largeAndSmallInts.registerTempTable("largeAndSmallInts") case class TestData2(a: Int, b: Int) @@ -55,7 +56,7 @@ object TestData { TestData2(2, 1) :: TestData2(2, 2) :: TestData2(3, 1) :: - TestData2(3, 2) :: Nil, 2).toSchemaRDD + TestData2(3, 2) :: Nil, 2).toDataFrame testData2.registerTempTable("testData2") case class DecimalData(a: BigDecimal, b: BigDecimal) @@ -67,7 +68,7 @@ object TestData { DecimalData(2, 1) :: DecimalData(2, 2) :: DecimalData(3, 1) :: - DecimalData(3, 2) :: Nil).toSchemaRDD + DecimalData(3, 2) :: Nil).toDataFrame decimalData.registerTempTable("decimalData") case class BinaryData(a: Array[Byte], b: Int) @@ -77,17 +78,17 @@ object TestData { BinaryData("22".getBytes(), 5) :: BinaryData("122".getBytes(), 3) :: BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil).toSchemaRDD + BinaryData("123".getBytes(), 4) :: Nil).toDataFrame binaryData.registerTempTable("binaryData") case class TestData3(a: Int, b: Option[Int]) val testData3 = TestSQLContext.sparkContext.parallelize( TestData3(1, None) :: - TestData3(2, 
Some(2)) :: Nil).toSchemaRDD + TestData3(2, Some(2)) :: Nil).toDataFrame testData3.registerTempTable("testData3") - val emptyTableData = logical.LocalRelation('a.int, 'b.int) + val emptyTableData = logical.LocalRelation($"a".int, $"b".int) case class UpperCaseData(N: Int, L: String) val upperCaseData = @@ -97,7 +98,7 @@ object TestData { UpperCaseData(3, "C") :: UpperCaseData(4, "D") :: UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil).toSchemaRDD + UpperCaseData(6, "F") :: Nil).toDataFrame upperCaseData.registerTempTable("upperCaseData") case class LowerCaseData(n: Int, l: String) @@ -106,7 +107,7 @@ object TestData { LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil).toSchemaRDD + LowerCaseData(4, "d") :: Nil).toDataFrame lowerCaseData.registerTempTable("lowerCaseData") case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) @@ -160,7 +161,7 @@ object TestData { TestSQLContext.sparkContext.parallelize( NullStrings(1, "abc") :: NullStrings(2, "ABC") :: - NullStrings(3, null) :: Nil) + NullStrings(3, null) :: Nil).toDataFrame nullStrings.registerTempTable("nullStrings") case class TableName(tableName: String) @@ -200,6 +201,6 @@ object TestData { TestSQLContext.sparkContext.parallelize( ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true) :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false) - :: Nil).toSchemaRDD + :: Nil).toDataFrame complexData.registerTempTable("complexData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 0c98120031242..95923f9aad931 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.Dsl.StringToColumn import org.apache.spark.sql.test._ /* Implicits */ @@ -28,25 +29,25 @@ class UDFSuite extends QueryTest { 
test("Simple UDF") { udf.register("strLenScala", (_: String).length) - assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4) + assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { udf.register("random0", () => { Math.random()}) - assert(sql("SELECT random0()").first().getDouble(0) >= 0.0) + assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { udf.register("strLenScala", (_: String).length + (_:Int)) - assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) + assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("struct UDF") { udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) - val result= + val result = sql("SELECT returnStruct('test', 'test2') as ret") - .select("ret.f1".attr).first().getString(0) - assert(result == "test") + .select($"ret.f1").head().getString(0) + assert(result === "test") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index fbc8704f7837b..0696a2335e63f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.sql import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types._ + @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { @@ -66,14 +68,14 @@ class UserDefinedTypeSuite extends QueryTest { test("register user type: MyDenseVector for MyLabeledPoint") { - val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } + val labels: RDD[Double] = 
pointsRDD.select('label).rdd.map { case Row(v: Double) => v } val labelsArrays: Array[Double] = labels.collect() assert(labelsArrays.size === 2) assert(labelsArrays.contains(1.0)) assert(labelsArrays.contains(0.0)) val features: RDD[MyDenseVector] = - pointsRDD.select('features).map { case Row(v: MyDenseVector) => v } + pointsRDD.select('features).rdd.map { case Row(v: MyDenseVector) => v } val featuresArrays: Array[MyDenseVector] = features.collect() assert(featuresArrays.size === 2) assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index e61f3c39631da..3d33484ab0eb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.columnar +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index c3a3f8ddc3ebf..fe9a69edbb920 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -104,14 +104,14 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be expectedQueryResult: => Seq[Int]): Unit = { test(query) { - val schemaRdd = sql(query) - val queryExecution = schemaRdd.queryExecution + val df = sql(query) + val queryExecution = df.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: 
$queryExecution") { - schemaRdd.collect().map(_(0)).toArray + df.collect().map(_(0)).toArray } - val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect { + val (readPartitions, readBatches) = df.queryExecution.executedPlan.collect { case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) }.head diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 67007b8c093ca..df108a9d262bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.scalatest.FunSuite import org.apache.spark.sql.{SQLConf, execution} +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -28,6 +29,7 @@ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ import org.apache.spark.sql.types._ + class PlannerSuite extends FunSuite { test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan @@ -40,7 +42,7 @@ class PlannerSuite extends FunSuite { } test("count is partially aggregated") { - val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed + val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed val planned = HashAggregation(query).head val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } @@ -48,14 +50,14 @@ class PlannerSuite extends FunSuite { } test("count distinct is partially aggregated") { - val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed + val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed 
val planned = HashAggregation(query) assert(planned.nonEmpty) } test("mixed aggregates are partially aggregated") { val query = - testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed + testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed val planned = HashAggregation(query) assert(planned.nonEmpty) } @@ -128,9 +130,9 @@ class PlannerSuite extends FunSuite { testData.limit(3).registerTempTable("tiny") sql("CACHE TABLE tiny") - val a = testData.as('a) - val b = table("tiny").as('b) - val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan + val a = testData.as("a") + val b = table("tiny").as("b") + val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala deleted file mode 100644 index 272c0d4cb2335..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ - -/* Implicit conversions */ -import org.apache.spark.sql.test.TestSQLContext._ - -/** - * This is an example TGF that uses UnresolvedAttributes 'name and 'age to access specific columns - * from the input data. These will be replaced during analysis with specific AttributeReferences - * and then bound to specific ordinals during query planning. While TGFs could also access specific - * columns using hand-coded ordinals, doing so violates data independence. - * - * Note: this is only a rough example of how TGFs can be expressed, the final version will likely - * involve a lot more sugar for cleaner use in Scala/Java/etc. - */ -case class ExampleTGF(input: Seq[Expression] = Seq('name, 'age)) extends Generator { - def children = input - protected def makeOutput() = 'nameAndAge.string :: Nil - - val Seq(nameAttr, ageAttr) = input - - override def eval(input: Row): TraversableOnce[Row] = { - val name = nameAttr.eval(input) - val age = ageAttr.eval(input).asInstanceOf[Int] - - Iterator( - new GenericRow(Array[Any](s"$name is $age years old")), - new GenericRow(Array[Any](s"Next year, $name will be ${age + 1} years old"))) - } -} - -class TgfSuite extends QueryTest { - val inputData = - logical.LocalRelation('name.string, 'age.int).loadData( - ("michael", 29) :: Nil - ) - - test("simple tgf example") { - checkAnswer( - inputData.generate(ExampleTGF()), - Seq( - Row("michael is 29 years old"), - Row("Next year, michael will be 30 years old"))) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 87c28c334d228..4e9472c60249e 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -23,11 +23,11 @@ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext._ class DebuggingSuite extends FunSuite { - test("SchemaRDD.debug()") { + test("DataFrame.debug()") { testData.debug() } - test("SchemaRDD.typeCheck()") { + test("DataFrame.typeCheck()") { testData.typeCheck() } } \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 94d14acccbb18..cb615388da0c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ @@ -29,6 +30,7 @@ import org.apache.spark.sql.{QueryTest, Row, SQLConf} class JsonSuite extends QueryTest { import org.apache.spark.sql.json.TestJsonData._ + TestJsonData test("Type promotion") { @@ -193,7 +195,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring with null in sampling") { - val jsonSchemaRDD = jsonRDD(jsonNullStruct) + val jsonDF = jsonRDD(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -202,8 +204,8 @@ class JsonSuite extends QueryTest { StructField("ip", StringType, true) :: StructField("nullstr", StringType, true):: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + assert(expectedSchema === jsonDF.schema) 
+ jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select nullstr, headers.Host from jsonTable"), @@ -212,7 +214,7 @@ class JsonSuite extends QueryTest { } test("Primitive field and type inferring") { - val jsonSchemaRDD = jsonRDD(primitiveFieldAndType) + val jsonDF = jsonRDD(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -223,9 +225,9 @@ class JsonSuite extends QueryTest { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -240,7 +242,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring") { - val jsonSchemaRDD = jsonRDD(complexFieldAndType1) + val jsonDF = jsonRDD(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, false), false), true) :: @@ -264,9 +266,9 @@ class JsonSuite extends QueryTest { StructField("field1", ArrayType(IntegerType, false), true) :: StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") // Access elements of a primitive array. checkAnswer( @@ -339,8 +341,8 @@ class JsonSuite extends QueryTest { } ignore("Complex field and type inferring (Ignored)") { - val jsonSchemaRDD = jsonRDD(complexFieldAndType1) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(complexFieldAndType1) + jsonDF.registerTempTable("jsonTable") // Right now, "field1" and "field2" are treated as aliases. We should fix it. 
checkAnswer( @@ -357,7 +359,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in primitive field values") { - val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) + val jsonDF = jsonRDD(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -367,9 +369,9 @@ class JsonSuite extends QueryTest { StructField("num_str", StringType, true) :: StructField("str_bool", StringType, true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -428,8 +430,8 @@ class JsonSuite extends QueryTest { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(primitiveFieldValueTypeConflict) + jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expreesion. // Number and Boolean conflict: resolve the type as boolean in this query. @@ -462,9 +464,9 @@ class JsonSuite extends QueryTest { // We should directly cast num_str to DecimalType and also need to do the right type promotion // in the Project. checkAnswer( - jsonSchemaRDD. + jsonDF. where('num_str > BigDecimal("92233720368547758060")). 
- select('num_str + 1.2 as Symbol("num")), + select(('num_str + 1.2).as("num")), Row(new java.math.BigDecimal("92233720368547758061.2")) ) @@ -481,7 +483,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in complex field values") { - val jsonSchemaRDD = jsonRDD(complexFieldValueTypeConflict) + val jsonDF = jsonRDD(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(IntegerType, false), true) :: @@ -491,9 +493,9 @@ class JsonSuite extends QueryTest { StructField("field", StringType, true) :: Nil), true) :: StructField("struct_array", StringType, true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -505,7 +507,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in array elements") { - val jsonSchemaRDD = jsonRDD(arrayElementTypeConflict) + val jsonDF = jsonRDD(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -513,9 +515,9 @@ class JsonSuite extends QueryTest { StructField("field", LongType, true) :: Nil), false), true) :: StructField("array3", ArrayType(StringType, false), true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -533,7 +535,7 @@ class JsonSuite extends QueryTest { } test("Handling missing fields") { - val jsonSchemaRDD = jsonRDD(missingFields) + val jsonDF = jsonRDD(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -543,16 +545,16 @@ class JsonSuite extends QueryTest { StructField("field", BooleanType, true) :: Nil), true) :: StructField("e", StringType, true) :: Nil) - assert(expectedSchema === 
jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") } test("Loading a JSON dataset from a text file") { val file = getTempFilePath("json") val path = file.toString primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonSchemaRDD = jsonFile(path) + val jsonDF = jsonFile(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -563,9 +565,9 @@ class JsonSuite extends QueryTest { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -619,11 +621,11 @@ class JsonSuite extends QueryTest { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonSchemaRDD1 = jsonFile(path, schema) + val jsonDF1 = jsonFile(path, schema) - assert(schema === jsonSchemaRDD1.schema) + assert(schema === jsonDF1.schema) - jsonSchemaRDD1.registerTempTable("jsonTable1") + jsonDF1.registerTempTable("jsonTable1") checkAnswer( sql("select * from jsonTable1"), @@ -636,11 +638,11 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val jsonSchemaRDD2 = jsonRDD(primitiveFieldAndType, schema) + val jsonDF2 = jsonRDD(primitiveFieldAndType, schema) - assert(schema === jsonSchemaRDD2.schema) + assert(schema === jsonDF2.schema) - jsonSchemaRDD2.registerTempTable("jsonTable2") + jsonDF2.registerTempTable("jsonTable2") checkAnswer( sql("select * from jsonTable2"), @@ -655,8 +657,8 @@ class JsonSuite extends QueryTest { } test("SPARK-2096 Correctly parse dot notations") { - val jsonSchemaRDD = jsonRDD(complexFieldAndType2) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = 
jsonRDD(complexFieldAndType2) + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), @@ -673,8 +675,8 @@ class JsonSuite extends QueryTest { } test("SPARK-3390 Complex arrays") { - val jsonSchemaRDD = jsonRDD(complexFieldAndType2) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(complexFieldAndType2) + jsonDF.registerTempTable("jsonTable") checkAnswer( sql( @@ -696,8 +698,8 @@ class JsonSuite extends QueryTest { } test("SPARK-3308 Read top level JSON arrays") { - val jsonSchemaRDD = jsonRDD(jsonArray) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(jsonArray) + jsonDF.registerTempTable("jsonTable") checkAnswer( sql( @@ -717,8 +719,8 @@ class JsonSuite extends QueryTest { val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - val jsonSchemaRDD = jsonRDD(corruptRecords) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(corruptRecords) + jsonDF.registerTempTable("jsonTable") val schema = StructType( StructField("_unparsed", StringType, true) :: @@ -726,7 +728,7 @@ class JsonSuite extends QueryTest { StructField("b", StringType, true) :: StructField("c", StringType, true) :: Nil) - assert(schema === jsonSchemaRDD.schema) + assert(schema === jsonDF.schema) // In HiveContext, backticks should be used to access columns starting with a underscore. 
checkAnswer( @@ -771,8 +773,8 @@ class JsonSuite extends QueryTest { } test("SPARK-4068: nulls in arrays") { - val jsonSchemaRDD = jsonRDD(nullsInArrays) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(nullsInArrays) + jsonDF.registerTempTable("jsonTable") val schema = StructType( StructField("field1", @@ -786,7 +788,7 @@ class JsonSuite extends QueryTest { StructField("field4", ArrayType(ArrayType(ArrayType(IntegerType, false), true), false), true) :: Nil) - assert(schema === jsonSchemaRDD.schema) + assert(schema === jsonDF.schema) checkAnswer( sql( @@ -801,7 +803,7 @@ class JsonSuite extends QueryTest { ) } - test("SPARK-4228 SchemaRDD to JSON") + test("SPARK-4228 DataFrame to JSON") { val schema1 = StructType( StructField("f1", IntegerType, false) :: @@ -818,10 +820,10 @@ class JsonSuite extends QueryTest { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val schemaRDD1 = applySchema(rowRDD1, schema1) - schemaRDD1.registerTempTable("applySchema1") - val schemaRDD2 = schemaRDD1.toSchemaRDD - val result = schemaRDD2.toJSON.collect() + val df1 = applySchema(rowRDD1, schema1) + df1.registerTempTable("applySchema1") + val df2 = df1.toDataFrame + val result = df2.toJSON.collect() assert(result(0) == "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}") assert(result(3) == "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}") @@ -839,16 +841,16 @@ class JsonSuite extends QueryTest { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val schemaRDD3 = applySchema(rowRDD2, schema2) - schemaRDD3.registerTempTable("applySchema2") - val schemaRDD4 = schemaRDD3.toSchemaRDD - val result2 = schemaRDD4.toJSON.collect() + val df3 = applySchema(rowRDD2, schema2) + df3.registerTempTable("applySchema2") + val df4 = df3.toDataFrame + val result2 = df4.toJSON.collect() assert(result2(1) == 
"{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) == "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonSchemaRDD = jsonRDD(primitiveFieldAndType) - val primTable = jsonRDD(jsonSchemaRDD.toJSON) + val jsonDF = jsonRDD(primitiveFieldAndType) + val primTable = jsonRDD(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), @@ -860,8 +862,8 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val complexJsonSchemaRDD = jsonRDD(complexFieldAndType1) - val compTable = jsonRDD(complexJsonSchemaRDD.toJSON) + val complexJsonDF = jsonRDD(complexFieldAndType1) + val compTable = jsonRDD(complexJsonDF.toJSON) compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 1e7d3e06fc196..e78145f4dda5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -21,9 +21,10 @@ import parquet.filter2.predicate.Operators._ import parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, Predicate, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Predicate, Row} +import org.apache.spark.sql.types._ import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} +import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. 
@@ -41,15 +42,17 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { val sqlContext = TestSQLContext private def checkFilterPredicate( - rdd: SchemaRDD, + rdd: DataFrame, predicate: Predicate, filterClass: Class[_ <: FilterPredicate], - checker: (SchemaRDD, Seq[Row]) => Unit, + checker: (DataFrame, Seq[Row]) => Unit, expected: Seq[Row]): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { - val query = rdd.select(output: _*).where(predicate) + val query = rdd + .select(output.map(e => new org.apache.spark.sql.Column(e)): _*) + .where(new org.apache.spark.sql.Column(predicate)) val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect { case plan: ParquetTableScan => plan.columnPruningPred @@ -71,13 +74,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { private def checkFilterPredicate (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit rdd: SchemaRDD): Unit = { + (implicit rdd: DataFrame): Unit = { checkFilterPredicate(rdd, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected) } private def checkFilterPredicate[T] (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: T) - (implicit rdd: SchemaRDD): Unit = { + (implicit rdd: DataFrame): Unit = { checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } @@ -91,26 +94,50 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } } + test("filter pushdown - short") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit rdd => + checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq [_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt [_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt [_]], 4) + checkFilterPredicate(Cast('_1, 
IntegerType) <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq [_]], 1) + checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) + + checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, + classOf[Operators.And], 3) + checkFilterPredicate(Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, + classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + test("filter pushdown - integer") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + 
checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } @@ -118,24 +145,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - long") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 
4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } @@ -143,24 +170,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - float") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } @@ -168,24 +195,24 @@ 
class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - double") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } @@ -197,30 +224,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate( '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) - checkFilterPredicate('_1 === "1", classOf[Eq [_]], "1") + checkFilterPredicate('_1 === "1", 
classOf[Eq[_]], "1") checkFilterPredicate('_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) - checkFilterPredicate('_1 < "2", classOf[Lt [_]], "1") - checkFilterPredicate('_1 > "3", classOf[Gt [_]], "4") + checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") + checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") - checkFilterPredicate(Literal("1") === '_1, classOf[Eq [_]], "1") - checkFilterPredicate(Literal("2") > '_1, classOf[Lt [_]], "1") - checkFilterPredicate(Literal("3") < '_1, classOf[Gt [_]], "4") - checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") - checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") + checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") + checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") + checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") + checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") + checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") - checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") + checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") - checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) + checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) } } def checkBinaryFilterPredicate (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit rdd: SchemaRDD): Unit = { - def checkBinaryAnswer(rdd: SchemaRDD, expected: Seq[Row]) = { + (implicit rdd: DataFrame): Unit = { + def checkBinaryAnswer(rdd: DataFrame, expected: Seq[Row]) = { assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted } @@ -231,7 +258,7 @@ class 
ParquetFilterSuite extends QueryTest with ParquetTest { def checkBinaryFilterPredicate (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) - (implicit rdd: SchemaRDD): Unit = { + (implicit rdd: DataFrame): Unit = { checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } @@ -249,16 +276,16 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkBinaryFilterPredicate( '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) - checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt [_]], 1.b) - checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt [_]], 4.b) + checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq [_]], 1.b) - checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt [_]], 1.b) - checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt [_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index a57e4e85a35ef..d9ab16baf9a66 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -32,12 +32,13 @@ import parquet.schema.{MessageType, MessageTypeParser} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf} +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types.DecimalType -import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport // with an empty configuration (it is after all not intended to be used in this way?) @@ -97,11 +98,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } test("fixed-length decimals") { - def makeDecimalRDD(decimal: DecimalType): SchemaRDD = + def makeDecimalRDD(decimal: DecimalType): DataFrame = sparkContext .parallelize(0 to 1000) .map(i => Tuple1(i / 100.0)) - .select('_1 cast decimal) + .select($"_1" cast decimal as "abcd") for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { withTempPath { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 1263ff818ea19..3d82f4bce7778 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -85,4 +85,15 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) } } + + test("SPARK-5309 strings stored using dictionary compression in parquet") { + 
withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { + + checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), + (0 until 10).map(i => Row("same", "run_" + i, 100))) + + checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), + List(Row("same", "run_5", 100))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 7900b3e8948d9..a33cf1172cac9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import scala.language.existentials + import org.apache.spark.sql._ import org.apache.spark.sql.types._ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 7385952861ee5..bb19ac232fcbe 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -23,6 +23,7 @@ import java.io._ import java.util.{ArrayList => JArrayList} import jline.{ConsoleReader, History} + import org.apache.commons.lang.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration @@ -39,7 +40,6 @@ import org.apache.thrift.transport.TSocket import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveShim -import org.apache.spark.sql.hive.thriftserver.HiveThriftServerShim private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala 
b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala index 166c56b9dfe20..ea9d61d8d0f5e 100644 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -32,7 +32,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging -import org.apache.spark.sql.{SQLConf, SchemaRDD, Row => SparkRow} +import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} @@ -71,7 +71,7 @@ private[hive] class SparkExecuteStatementOperation( sessionToActivePool: SMap[SessionHandle, String]) extends ExecuteStatementOperation(parentSession, statement, confOverlay) with Logging { - private var result: SchemaRDD = _ + private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ @@ -202,7 +202,7 @@ private[hive] class SparkExecuteStatementOperation( val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean if (useIncrementalCollect) { - result.toLocalIterator + result.rdd.toLocalIterator } else { result.collect().iterator } diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index eaf7a1ddd4996..71e3954b2c7ac 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -30,7 +30,7 @@ import 
org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging -import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD} +import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} @@ -72,7 +72,7 @@ private[hive] class SparkExecuteStatementOperation( // NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging { - private var result: SchemaRDD = _ + private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ @@ -173,7 +173,7 @@ private[hive] class SparkExecuteStatementOperation( val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean if (useIncrementalCollect) { - result.toLocalIterator + result.rdd.toLocalIterator } else { result.collect().iterator } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 9d2cfd8e0d669..b746942cb1067 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -64,15 +64,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution { val logical = plan } + new this.QueryExecution(plan) - override def sql(sqlText: String): SchemaRDD = { + override def sql(sqlText: String): DataFrame = { val substituted = new 
VariableSubstitution().substitute(hiveconf, sqlText) // TODO: Create a framework for registering parsers instead of just hardcoding if statements. if (conf.dialect == "sql") { super.sql(substituted) } else if (conf.dialect == "hiveql") { - new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted))) + new DataFrame(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted))) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") } @@ -352,7 +352,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override protected[sql] val planner = hivePlanner /** Extends QueryExecution with hive specific features. */ - protected[sql] abstract class QueryExecution extends super.QueryExecution { + protected[sql] class QueryExecution(logicalPlan: LogicalPlan) + extends super.QueryExecution(logicalPlan) { /** * Returns the result as a hive compatible sequence of strings. For native commands, the diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 5e29e57d93585..399e58b259a45 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1002,11 +1002,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } /* Stars (*) */ - case Token("TOK_ALLCOLREF", Nil) => Star(None) + case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None) // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only // has a single child which is tableName. 
case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) => - Star(Some(name)) + UnresolvedStar(Some(name)) /* Aggregate Functions */ case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg)) @@ -1145,7 +1145,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_FUNCTION", Token(name, Nil) :: args) => UnresolvedFunction(name, args.map(nodeToExpr)) case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, Star(None) :: Nil) + UnresolvedFunction(name, UnresolvedStar(None) :: Nil) /* Literals */ case Token("TOK_NULL", Nil) => Literal(null, NullType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 6952b126cf894..ace9329cd5821 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy} +import org.apache.spark.sql.{Column, DataFrame, SQLContext, Strategy} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate @@ -55,16 +55,15 @@ private[hive] trait HiveStrategies { */ @Experimental object ParquetConversion extends Strategy { - implicit class LogicalPlanHacks(s: SchemaRDD) { - def lowerCase = - new SchemaRDD(s.sqlContext, s.logicalPlan) + implicit class LogicalPlanHacks(s: DataFrame) { + def lowerCase = new DataFrame(s.sqlContext, s.logicalPlan) def addPartitioningAttributes(attrs: Seq[Attribute]) = { // Don't add the partitioning key if its already present in the data. 
if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) { s } else { - new SchemaRDD( + new DataFrame( s.sqlContext, s.logicalPlan transform { case p: ParquetRelation => p.copy(partitioningAttributes = attrs) @@ -97,13 +96,13 @@ private[hive] trait HiveStrategies { // We are going to throw the predicates and projection back at the whole optimization // sequence so lets unresolve all the attributes, allowing them to be rebound to the // matching parquet attributes. - val unresolvedOtherPredicates = otherPredicates.map(_ transform { + val unresolvedOtherPredicates = new Column(otherPredicates.map(_ transform { case a: AttributeReference => UnresolvedAttribute(a.name) - }).reduceOption(And).getOrElse(Literal(true)) + }).reduceOption(And).getOrElse(Literal(true))) - val unresolvedProjection = projectList.map(_ transform { + val unresolvedProjection: Seq[Column] = projectList.map(_ transform { case a: AttributeReference => UnresolvedAttribute(a.name) - }) + }).map(new Column(_)) try { if (relation.hiveQlTable.isPartitioned) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 47431cef03e13..7c1d1133c3425 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -68,6 +68,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { System.clearProperty("spark.hostPort") CommandProcessorFactory.clean(hiveconf) + hiveconf.set("hive.plan.serialization.format", "javaXML") + lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath @@ -99,7 +101,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(sql)) override def executePlan(plan: LogicalPlan): 
this.QueryExecution = - new this.QueryExecution { val logical = plan } + new this.QueryExecution(plan) /** Fewer partitions to speed up testing. */ protected[sql] override lazy val conf: SQLConf = new SQLConf { @@ -150,8 +152,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val describedTable = "DESCRIBE (\\w+)".r - protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution { - lazy val logical = HiveQl.parseSql(hql) + protected[hive] class HiveQLQueryExecution(hql: String) + extends this.QueryExecution(HiveQl.parseSql(hql)) { def hiveExec() = runSqlHive(hql) override def toString = hql + "\n" + super.toString } @@ -159,7 +161,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { /** * Override QueryExecution with special debug workflow. */ - abstract class QueryExecution extends super.QueryExecution { + class QueryExecution(logicalPlan: LogicalPlan) + extends super.QueryExecution(logicalPlan) { override lazy val analyzed = { val describedTables = logical match { case HiveNativeCommand(describedTable(tbl)) => tbl :: Nil @@ -395,7 +398,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } - clearCache() + cacheManager.clearCache() loadedTables.clear() catalog.cachedDataSourceTables.invalidateAll() catalog.client.getAllTables("default").foreach { t => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 42bc8a0b67933..91af35f0965c0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -239,7 +239,7 @@ case class InsertIntoHiveTable( } // Invalidate the cache. 
- sqlContext.invalidateCache(table) + sqlContext.cacheManager.invalidateCache(table) // It would be nice to just return the childRdd unchanged so insert operations could be chained, // however for now we return an empty list to simplify compatibility checks with hive, which diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 91f9da35abeee..4814cb7ebfe51 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -54,7 +54,7 @@ case class DropTable( val hiveContext = sqlContext.asInstanceOf[HiveContext] val ifExistsClause = if (ifExists) "IF EXISTS " else "" try { - hiveContext.tryUncacheQuery(hiveContext.table(tableName)) + hiveContext.cacheManager.tryUncacheQuery(hiveContext.table(tableName)) } catch { // This table's metadata is not in case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/package-info.java b/sql/hive/src/main/scala/org/apache/spark/sql/hive/package-info.java index 8b29fa7d1a8f7..4b23fbf6e7362 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/package-info.java +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/package-info.java @@ -15,4 +15,4 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.hive; \ No newline at end of file +package org.apache.spark.sql.hive; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala index f320d732fb77a..ba391293884bd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -36,12 +36,12 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer - * @param rdd the [[SchemaRDD]] to be executed + * @param rdd the [[DataFrame]] to be executed * @param exists true for make sure the keywords are listed in the output, otherwise * to make sure none of the keyword are not listed in the output * @param keywords keyword in string array */ - def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) { + def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) { val outputs = rdd.collect().map(_.mkString).mkString for (key <- keywords) { if (exists) { @@ -54,10 +54,10 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. - * @param rdd the [[SchemaRDD]] to be executed + * @param rdd the [[DataFrame]] to be executed * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. */ - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = { + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. 
@@ -101,7 +101,7 @@ class QueryTest extends PlanTest { } } - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = { + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { checkAnswer(rdd, Seq(expectedAnswer)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index f95a6b43af357..61e5117feab10 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{QueryTest, SchemaRDD} +import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.storage.RDDBlockId class CachedTableSuite extends QueryTest { @@ -28,7 +28,7 @@ class CachedTableSuite extends QueryTest { * Throws a test failed exception when the number of cached tables differs from the expected * number. */ - def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 0e6636d38ed3c..4dd96bd5a1b77 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -52,7 +52,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the table has been updated. 
checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.toSchemaRDD.collect().toSeq ++ testData.toSchemaRDD.collect().toSeq + testData.toDataFrame.collect().toSeq ++ testData.toDataFrame.collect().toSeq ) // Now overwrite. @@ -82,8 +82,8 @@ class InsertIntoHiveTableSuite extends QueryTest { val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i")))) - val schemaRDD = applySchema(rowRDD, schema) - schemaRDD.registerTempTable("tableWithMapValue") + val df = applySchema(rowRDD, schema) + df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m MAP )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -127,8 +127,8 @@ class InsertIntoHiveTableSuite extends QueryTest { val schema = StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false)))) val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) - val schemaRDD = applySchema(rowRDD, schema) - schemaRDD.registerTempTable("tableWithArrayValue") + val df = applySchema(rowRDD, schema) + df.registerTempTable("tableWithArrayValue") sql("CREATE TABLE hiveTableWithArrayValue(a Array )") sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue") @@ -144,8 +144,8 @@ class InsertIntoHiveTableSuite extends QueryTest { StructField("m", MapType(StringType, StringType, valueContainsNull = false)))) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(Map(s"key$i" -> s"value$i")))) - val schemaRDD = applySchema(rowRDD, schema) - schemaRDD.registerTempTable("tableWithMapValue") + val df = applySchema(rowRDD, schema) + df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m Map )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -161,8 +161,8 
@@ class InsertIntoHiveTableSuite extends QueryTest { StructField("s", StructType(Seq(StructField("f", StringType, nullable = false)))))) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(Row(s"value$i")))) - val schemaRDD = applySchema(rowRDD, schema) - schemaRDD.registerTempTable("tableWithStructValue") + val df = applySchema(rowRDD, schema) + df.registerTempTable("tableWithStructValue") sql("CREATE TABLE hiveTableWithStructValue(s Struct )") sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index df72be7746ac6..60619f5d99578 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -27,11 +27,12 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SQLConf, Row, SchemaRDD} case class TestData(a: Int, b: String) @@ -367,7 +368,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") } - test("SchemaRDD toString") { + test("DataFrame toString") { sql("SHOW TABLES").toString sql("SELECT * FROM src").toString } @@ -473,12 +474,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } - def isExplanation(result: SchemaRDD) = { + def isExplanation(result: DataFrame) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => 
plan } explanation.contains("== Physical Plan ==") } - test("SPARK-1704: Explain commands as a SchemaRDD") { + test("SPARK-1704: Explain commands as a DataFrame") { sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") val rdd = sql("explain select key, count(value) from src group by key") @@ -508,6 +509,11 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(sql("select key from src having key > 490").collect().size < 100) } + test("SPARK-5367: resolve star expression in udf") { + assert(sql("select concat(*) from src limit 5").collect().size == 5) + assert(sql("select array(*) from src limit 5").collect().size == 5) + } + test("Query Hive native command execution result") { val tableName = "test_native_commands" @@ -842,7 +848,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { val testVal = "test.val.0" val nonexistentKey = "nonexistent" val KV = "([^=]+)=([^=]*)".r - def collectResults(rdd: SchemaRDD): Set[(String, String)] = + def collectResults(rdd: DataFrame): Set[(String, String)] = rdd.collect().map { case Row(key: String, value: String) => key -> value case Row(KV(key, value)) => key -> value diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 16f77a438e1ae..8fb5e050a237a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.Row +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.Row import org.apache.spark.util.Utils @@ -82,10 +83,10 @@ class HiveTableScanSuite extends HiveComparisonTest { sql("create table spark_4959 (col1 
string)") sql("""insert into table spark_4959 select "hi" from src limit 1""") table("spark_4959").select( - 'col1.as('CaseSensitiveColName), - 'col1.as('CaseSensitiveColName2)).registerTempTable("spark_4959_2") + 'col1.as("CaseSensitiveColName"), + 'col1.as("CaseSensitiveColName2")).registerTempTable("spark_4959_2") - assert(sql("select CaseSensitiveColName from spark_4959_2").first() === Row("hi")) - assert(sql("select casesensitivecolname from spark_4959_2").first() === Row("hi")) + assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi")) + assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index f2374a215291b..dd0df1a9f6320 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -58,7 +58,7 @@ class HiveUdfSuite extends QueryTest { | getStruct(1).f3, | getStruct(1).f4, | getStruct(1).f5 FROM src LIMIT 1 - """.stripMargin).first() === Row(1, 2, 3, 4, 5)) + """.stripMargin).head() === Row(1, 2, 3, 4, 5)) } test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 7f9f1ac7cd80d..eb7a7750af02d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -222,7 +222,7 @@ class SQLQuerySuite extends QueryTest { sql("SELECT distinct key FROM src order by key").collect().toSeq) } - test("SPARK-4963 SchemaRDD sample on mutable row return wrong result") { + test("SPARK-4963 DataFrame sample on 
mutable row return wrong result") { sql("SELECT * FROM src WHERE key % 2 = 0") .sample(withReplacement = false, fraction = 0.3) .registerTempTable("sampled") @@ -267,4 +267,19 @@ class SQLQuerySuite extends QueryTest { sql("DROP TABLE nullValuesInInnerComplexTypes") dropTempTable("testTable") } + + test("SPARK-4296 Grouping field with Hive UDF as sub expression") { + val rdd = sparkContext.makeRDD( """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) + jsonRDD(rdd).registerTempTable("data") + checkAnswer( + sql("SELECT concat(a, '-', b), year(c) FROM data GROUP BY concat(a, '-', b), year(c)"), + Row("str-1", 1970)) + + dropTempTable("data") + + jsonRDD(rdd).registerTempTable("data") + checkAnswer(sql("SELECT year(c) + 1 FROM data GROUP BY year(c) + 1"), Row(1971)) + + dropTempTable("data") + } } diff --git a/streaming/pom.xml b/streaming/pom.xml index 22b0d714b57f6..d032491e2ff83 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -40,6 +40,10 @@ spark-core_${scala.binary.version} ${project.version} + + com.google.guava + guava + org.eclipse.jetty jetty-server @@ -95,6 +99,14 @@ + + + org.apache.maven.plugins + maven-shade-plugin + + true + + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 40434b1f9b709..6500608bba87c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -28,21 +28,16 @@ import java.io.File */ class FailureSuite extends TestSuiteBase with Logging { - var directory = "FailureSuite" + val directory = Utils.createTempDir().getAbsolutePath val numBatches = 30 override def batchDuration = Milliseconds(1000) override def useManualClock = false - override def beforeFunction() { - super.beforeFunction() - Utils.deleteRecursively(new File(directory)) - } - override def afterFunction() { - super.afterFunction() 
Utils.deleteRecursively(new File(directory)) + super.afterFunction() } test("multiple failures with map") { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 902bdda59860e..d3e327b2497b7 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -43,8 +43,11 @@ import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} /** * Common application master functionality for Spark on Yarn. */ -private[spark] class ApplicationMaster(args: ApplicationMasterArguments, - client: YarnRMClient) extends Logging { +private[spark] class ApplicationMaster( + args: ApplicationMasterArguments, + client: YarnRMClient) + extends Logging { + // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. @@ -231,6 +234,24 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, reporterThread = launchReporterThread() } + /** + * Create an actor that communicates with the driver. + * + * In cluster mode, the AM and the driver belong to same process + * so the AM actor need not monitor lifecycle of the driver. 
+ */ + private def runAMActor( + host: String, + port: String, + isDriver: Boolean): Unit = { + val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( + SparkEnv.driverActorSystemName, + host, + port, + YarnSchedulerBackend.ACTOR_NAME) + actor = actorSystem.actorOf(Props(new AMActor(driverUrl, isDriver)), name = "YarnAM") + } + private def runDriver(securityMgr: SecurityManager): Unit = { addAmIpFilter() userClassThread = startUserClass() @@ -245,6 +266,11 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, ApplicationMaster.EXIT_SC_NOT_INITED, "Timed out waiting for SparkContext.") } else { + actorSystem = sc.env.actorSystem + runAMActor( + sc.getConf.get("spark.driver.host"), + sc.getConf.get("spark.driver.port"), + isDriver = true) registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) userClassThread.join() } @@ -253,7 +279,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { actorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, conf = sparkConf, securityManager = securityMgr)._1 - actor = waitForSparkDriver() + waitForSparkDriver() addAmIpFilter() registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) @@ -367,7 +393,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, } } - private def waitForSparkDriver(): ActorRef = { + private def waitForSparkDriver(): Unit = { logInfo("Waiting for Spark driver to be reachable.") var driverUp = false val hostport = args.userArgs(0) @@ -399,12 +425,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, sparkConf.set("spark.driver.host", driverHost) sparkConf.set("spark.driver.port", driverPort.toString) - val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( - SparkEnv.driverActorSystemName, - driverHost, - driverPort.toString, - YarnSchedulerBackend.ACTOR_NAME) - actorSystem.actorOf(Props(new 
AMActor(driverUrl)), name = "YarnAM") + runAMActor(driverHost, driverPort.toString, isDriver = false) } /** Add the Yarn IP filter that is required for properly securing the UI. */ @@ -462,9 +483,9 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, } /** - * Actor that communicates with the driver in client deploy mode. + * An actor that communicates with the driver's scheduler backend. */ - private class AMActor(driverUrl: String) extends Actor { + private class AMActor(driverUrl: String, isDriver: Boolean) extends Actor { var driver: ActorSelection = _ override def preStart() = { @@ -474,13 +495,21 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // we can monitor Lifecycle Events. driver ! "Hello" driver ! RegisterClusterManager - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + // In cluster mode, the AM can directly monitor the driver status instead + // of trying to deduce it from the lifecycle of the driver's actor + if (!isDriver) { + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + } } override def receive = { case x: DisassociatedEvent => logInfo(s"Driver terminated or disconnected! Shutting down. $x") - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + // In cluster mode, do not rely on the disassociated event to exit + // This avoids potentially reporting incorrect exit codes if the driver fails + if (!isDriver) { + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + } case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. 
$x") diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index d00f29665a58f..3849586c6111e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -32,6 +32,8 @@ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.util.RackResolver +import org.apache.log4j.{Level, Logger} + import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend @@ -60,6 +62,11 @@ private[yarn] class YarnAllocator( import YarnAllocator._ + // RackResolver logs an INFO message whenever it resolves a rack, which is way too often. + if (Logger.getLogger(classOf[RackResolver]).getLevel == null) { + Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN) + } + // Visible for testing. 
val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]] diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 4bff846123619..4e39c1d58011b 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -17,12 +17,9 @@ package org.apache.spark.deploy.yarn -import java.lang.{Boolean => JBoolean} import java.io.File -import java.util.{Collections, Set => JSet} import java.util.regex.Matcher import java.util.regex.Pattern -import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.HashMap @@ -32,7 +29,6 @@ import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.records.{Priority, ApplicationAccessType} -import org.apache.hadoop.yarn.util.RackResolver import org.apache.hadoop.conf.Configuration import org.apache.spark.{SecurityManager, SparkConf} diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index be55d26f1cf61..72ec4d6b34af6 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -17,33 +17,17 @@ package org.apache.spark.scheduler.cluster -import org.apache.hadoop.yarn.util.RackResolver - import org.apache.spark._ import org.apache.spark.deploy.yarn.ApplicationMaster -import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.util.Utils /** * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of * ApplicationMaster, etc is done */ -private[spark] 
class YarnClusterScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) { +private[spark] class YarnClusterScheduler(sc: SparkContext) extends YarnScheduler(sc) { logInfo("Created YarnClusterScheduler") - // Nothing else for now ... initialize application master : which needs a SparkContext to - // determine how to allocate. - // Note that only the first creation of a SparkContext influences (and ideally, there must be - // only one SparkContext, right ?). Subsequent creations are ignored since executors are already - // allocated by then. - - // By default, rack is unknown - override def getRackForHost(hostPort: String): Option[String] = { - val host = Utils.parseHostPort(hostPort)._1 - Option(RackResolver.resolve(sc.hadoopConfiguration, host).getNetworkLocation) - } - override def postStartHook() { ApplicationMaster.sparkContextInitialized(sc) super.postStartHook() diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala similarity index 77% rename from yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala rename to yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala index 2fa24cc43325e..4ebf3af12b381 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala @@ -19,14 +19,18 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.yarn.util.RackResolver +import org.apache.log4j.{Level, Logger} + import org.apache.spark._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils -/** - * This scheduler launches executors through Yarn - by calling into Client to launch the Spark AM. 
- */ -private[spark] class YarnClientClusterScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) { +private[spark] class YarnScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) { + + // RackResolver logs an INFO message whenever it resolves a rack, which is way too often. + if (Logger.getLogger(classOf[RackResolver]).getLevel == null) { + Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN) + } // By default, rack is unknown override def getRackForHost(hostPort: String): Option[String] = {