diff --git a/core/pom.xml b/core/pom.xml index 868834dd505ef..6cd1965ec37c2 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -275,7 +275,7 @@ org.tachyonproject tachyon-client - 0.6.1 + 0.5.0 org.apache.hadoop diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index bcf832467f00b..330df1d59a9b1 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -18,8 +18,6 @@ package org.apache.spark import java.io.{ObjectInputStream, Serializable} -import java.util.concurrent.atomic.AtomicLong -import java.lang.ThreadLocal import scala.collection.generic.Growable import scala.collection.mutable.Map @@ -109,7 +107,7 @@ class Accumulable[R, T] ( * The typical use of this method is to directly mutate the local value, eg., to add * an element to a Set. */ - def localValue = value_ + def localValue: R = value_ /** * Set the accumulator's value; only allowed on master. @@ -137,7 +135,7 @@ class Accumulable[R, T] ( Accumulators.register(this, false) } - override def toString = if (value_ == null) "null" else value_.toString + override def toString: String = if (value_ == null) "null" else value_.toString } /** @@ -257,22 +255,22 @@ object AccumulatorParam { implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 - def zero(initialValue: Double) = 0.0 + def zero(initialValue: Double): Double = 0.0 } implicit object IntAccumulatorParam extends AccumulatorParam[Int] { def addInPlace(t1: Int, t2: Int): Int = t1 + t2 - def zero(initialValue: Int) = 0 + def zero(initialValue: Int): Int = 0 } implicit object LongAccumulatorParam extends AccumulatorParam[Long] { - def addInPlace(t1: Long, t2: Long) = t1 + t2 - def zero(initialValue: Long) = 0L + def addInPlace(t1: Long, t2: Long): Long = t1 + t2 + def zero(initialValue: Long): Long = 0L } implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { - def addInPlace(t1: Float, t2: Float) = t1 + t2 - def zero(initialValue: Float) = 0f + def addInPlace(t1: Float, t2: Float): Float = t1 + t2 + def zero(initialValue: Float): Float = 0f } // TODO: Add AccumulatorParams for other types, e.g. lists and strings @@ -351,6 +349,7 @@ private[spark] object Accumulators extends Logging { } } - def stringifyPartialValue(partialValue: Any) = "%s".format(partialValue) - def stringifyValue(value: Any) = "%s".format(value) + def stringifyPartialValue(partialValue: Any): String = "%s".format(partialValue) + + def stringifyValue(value: Any): String = "%s".format(value) } diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 9a7cd4523e5ab..fc8cdde9348ee 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -74,7 +74,7 @@ class ShuffleDependency[K, V, C]( val mapSideCombine: Boolean = false) extends Dependency[Product2[K, V]] { - override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]] + override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]] val shuffleId: Int = _rdd.context.newShuffleId() @@ -91,7 +91,7 @@ class ShuffleDependency[K, V, C]( */ @DeveloperApi class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { - override def getParents(partitionId: Int) = List(partitionId) + override def getParents(partitionId: Int): List[Int] = List(partitionId) } @@ -107,7 +107,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) extends NarrowDependency[T](rdd) { - override def getParents(partitionId: Int) = { + override def getParents(partitionId: Int): List[Int] = { if (partitionId >= outStart && partitionId < outStart + length) { List(partitionId - outStart + inStart) } else { diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index e97a7375a267b..91f9ef8ce7185 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -168,7 +168,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } } - def jobIds = Seq(jobWaiter.jobId) + def jobIds: Seq[Int] = Seq(jobWaiter.jobId) } @@ -276,7 +276,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { override def value: Option[Try[T]] = p.future.value - def jobIds = jobs + def jobIds: Seq[Int] = jobs } diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 69178da1a7773..715f292f03469 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -65,7 +65,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskSchedule super.preStart() } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case Heartbeat(executorId, taskMetrics, blockManagerId) => val unknownExecutor = !scheduler.executorHeartbeatReceived( executorId, taskMetrics, blockManagerId) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 6e4edc7c80d7a..c9426c5de23a2 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -43,7 +43,7 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster extends Actor with ActorLogReceive with Logging { val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => val hostPort = sender.path.address.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index e53a78ead2c0e..b8d244408bc5b 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -76,7 +76,7 @@ object Partitioner { * produce an unexpected or incorrect result. */ class HashPartitioner(partitions: Int) extends Partitioner { - def numPartitions = partitions + def numPartitions: Int = partitions def getPartition(key: Any): Int = key match { case null => 0 @@ -154,7 +154,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } } - def numPartitions = rangeBounds.length + 1 + def numPartitions: Int = rangeBounds.length + 1 private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K] diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala index 55cb25946c2ad..cb2cae185256a 100644 --- a/core/src/main/scala/org/apache/spark/SerializableWritable.scala +++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala @@ -28,8 +28,10 @@ import org.apache.spark.util.Utils @DeveloperApi class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable { - def value = t - override def toString = t.toString + + def value: T = t + + override def toString: String = t.toString private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { out.defaultWriteObject() diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 2ca19f53d2f07..0c123c96b8d7b 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -133,7 +133,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } /** Set multiple parameters together */ - def setAll(settings: Traversable[(String, String)]) = { + def setAll(settings: Traversable[(String, String)]): SparkConf = { this.settings.putAll(settings.toMap.asJava) this } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 228ff715fe7cb..a70be16f77eeb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -986,7 +986,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli union(Seq(first) ++ rest) /** Get an RDD that has no partitions or elements. */ - def emptyRDD[T: ClassTag] = new EmptyRDD[T](this) + def emptyRDD[T: ClassTag]: EmptyRDD[T] = new EmptyRDD[T](this) // Methods for creating shared variables @@ -994,7 +994,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" * values to using the `+=` method. Only the driver can access the accumulator's `value`. */ - def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = + def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]): Accumulator[T] = { val acc = new Accumulator(initialValue, param) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) @@ -1006,7 +1006,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the * driver can access the accumulator's `value`. */ - def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = { + def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) + : Accumulator[T] = { val acc = new Accumulator(initialValue, param, Some(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc @@ -1018,7 +1019,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ - def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) = { + def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) + : Accumulable[R, T] = { val acc = new Accumulable(initialValue, param) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc @@ -1031,7 +1033,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ - def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) = { + def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) + : Accumulable[R, T] = { val acc = new Accumulable(initialValue, param, Some(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc @@ -1209,7 +1212,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId) /** The version of Spark on which this application is running. */ - def version = SPARK_VERSION + def version: String = SPARK_VERSION /** * Return a map from the slave to the max memory available for caching and the remaining @@ -1659,7 +1662,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } - def getCheckpointDir = checkpointDir + def getCheckpointDir: Option[String] = checkpointDir /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ def defaultParallelism: Int = { @@ -1900,28 +1903,28 @@ object SparkContext extends Logging { "backward compatibility.", "1.3.0") object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 - def zero(initialValue: Double) = 0.0 + def zero(initialValue: Double): Double = 0.0 } @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + "backward compatibility.", "1.3.0") object IntAccumulatorParam extends AccumulatorParam[Int] { def addInPlace(t1: Int, t2: Int): Int = t1 + t2 - def zero(initialValue: Int) = 0 + def zero(initialValue: Int): Int = 0 } @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + "backward compatibility.", "1.3.0") object LongAccumulatorParam extends AccumulatorParam[Long] { - def addInPlace(t1: Long, t2: Long) = t1 + t2 - def zero(initialValue: Long) = 0L + def addInPlace(t1: Long, t2: Long): Long = t1 + t2 + def zero(initialValue: Long): Long = 0L } @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + "backward compatibility.", "1.3.0") object FloatAccumulatorParam extends AccumulatorParam[Float] { - def addInPlace(t1: Float, t2: Float) = t1 + t2 - def zero(initialValue: Float) = 0f + def addInPlace(t1: Float, t2: Float): Float = t1 + t2 + def zero(initialValue: Float): Float = 0f } // The following deprecated functions have already been moved to `object RDD` to @@ -1931,18 +1934,18 @@ object SparkContext extends Logging { @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) - (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { + (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null): PairRDDFunctions[K, V] = RDD.rddToPairRDDFunctions(rdd) - } @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") - def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = RDD.rddToAsyncRDDActions(rdd) + def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]): AsyncRDDActions[T] = + RDD.rddToAsyncRDDActions(rdd) @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( - rdd: RDD[(K, V)]) = { + rdd: RDD[(K, V)]): SequenceFileRDDFunctions[K, V] = { val kf = implicitly[K => Writable] val vf = implicitly[V => Writable] // Set the Writable class to null and `SequenceFileRDDFunctions` will use Reflection to get it @@ -1954,16 +1957,17 @@ object SparkContext extends Logging { @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( - rdd: RDD[(K, V)]) = + rdd: RDD[(K, V)]): OrderedRDDFunctions[K, V, (K, V)] = RDD.rddToOrderedRDDFunctions(rdd) @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") - def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = RDD.doubleRDDToDoubleRDDFunctions(rdd) + def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]): DoubleRDDFunctions = + RDD.doubleRDDToDoubleRDDFunctions(rdd) @deprecated("Replaced by implicit functions in the RDD companion object. This is " + "kept here only for backward compatibility.", "1.3.0") - def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = + def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]): DoubleRDDFunctions = RDD.numericRDDToDoubleRDDFunctions(rdd) // The following deprecated functions have already been moved to `object WritableFactory` to @@ -2134,7 +2138,7 @@ object SparkContext extends Logging { (backend, scheduler) case LOCAL_N_REGEX(threads) => - def localCpuCount = Runtime.getRuntime.availableProcessors() + def localCpuCount: Int = Runtime.getRuntime.availableProcessors() // local[*] estimates the number of cores on the machine; local[N] uses exactly N threads. val threadCount = if (threads == "*") localCpuCount else threads.toInt if (threadCount <= 0) { @@ -2146,7 +2150,7 @@ object SparkContext extends Logging { (backend, scheduler) case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - def localCpuCount = Runtime.getRuntime.availableProcessors() + def localCpuCount: Int = Runtime.getRuntime.availableProcessors() // local[*, M] means the number of cores on the computer with M failures // local[N, M] means exactly N threads with M failures val threadCount = if (threads == "*") localCpuCount else threads.toInt diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index edbdda8a0bcb6..34ee3a48f8e74 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -45,8 +45,7 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { */ def getJobIdsForGroup(jobGroup: String): Array[Int] = { jobProgressListener.synchronized { - val jobData = jobProgressListener.jobIdToData.valuesIterator - jobData.filter(_.jobGroup.orNull == jobGroup).map(_.jobId).toArray + jobProgressListener.jobGroupToJobIds.getOrElse(jobGroup, Seq.empty).toArray } } diff --git a/core/src/main/scala/org/apache/spark/TaskState.scala b/core/src/main/scala/org/apache/spark/TaskState.scala index c415fe99b105e..fe19f07e32d1b 100644 --- a/core/src/main/scala/org/apache/spark/TaskState.scala +++ b/core/src/main/scala/org/apache/spark/TaskState.scala @@ -27,9 +27,9 @@ private[spark] object TaskState extends Enumeration { type TaskState = Value - def isFailed(state: TaskState) = (LOST == state) || (FAILED == state) + def isFailed(state: TaskState): Boolean = (LOST == state) || (FAILED == state) - def isFinished(state: TaskState) = FINISHED_STATES.contains(state) + def isFinished(state: TaskState): Boolean = FINISHED_STATES.contains(state) def toMesos(state: TaskState): MesosTaskState = state match { case LAUNCHING => MesosTaskState.TASK_STARTING diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 35b324ba6f573..398ca41e16151 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -107,7 +107,7 @@ private[spark] object TestUtils { private class JavaSourceFromString(val name: String, val code: String) extends SimpleJavaFileObject(createURI(name), SOURCE) { - override def getCharContent(ignoreEncodingErrors: Boolean) = code + override def getCharContent(ignoreEncodingErrors: Boolean): String = code } /** Creates a compiled class with the given name. Class file will be placed in destDir. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 3e9beb670f7ad..18ccd625fc8d1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -179,7 +179,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] = wrapRDD(rdd.subtract(other, p)) - override def toString = rdd.toString + override def toString: String = rdd.toString /** Assign a name to this RDD */ def setName(name: String): JavaRDD[T] = { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 6d6ed693be752..3be6783bba49d 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -108,7 +108,7 @@ class JavaSparkContext(val sc: SparkContext) private[spark] val env = sc.env - def statusTracker = new JavaSparkStatusTracker(sc) + def statusTracker: JavaSparkStatusTracker = new JavaSparkStatusTracker(sc) def isLocal: java.lang.Boolean = sc.isLocal diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index 71b26737b8c02..8f9647eea9e25 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.api.java +import java.util.Map.Entry + import com.google.common.base.Optional import java.{util => ju} @@ -30,8 +32,8 @@ private[spark] object JavaUtils { } // Workaround for SPARK-3926 / SI-8911 - def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) = - new SerializableMapWrapper(underlying) + def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]): SerializableMapWrapper[A, B] + = new SerializableMapWrapper(underlying) // Implementation is copied from scala.collection.convert.Wrappers.MapWrapper, // but implements java.io.Serializable. It can't just be subclassed to make it @@ -40,36 +42,33 @@ private[spark] object JavaUtils { class SerializableMapWrapper[A, B](underlying: collection.Map[A, B]) extends ju.AbstractMap[A, B] with java.io.Serializable { self => - override def size = underlying.size + override def size: Int = underlying.size override def get(key: AnyRef): B = try { - underlying get key.asInstanceOf[A] match { - case None => null.asInstanceOf[B] - case Some(v) => v - } + underlying.getOrElse(key.asInstanceOf[A], null.asInstanceOf[B]) } catch { case ex: ClassCastException => null.asInstanceOf[B] } override def entrySet: ju.Set[ju.Map.Entry[A, B]] = new ju.AbstractSet[ju.Map.Entry[A, B]] { - def size = self.size + override def size: Int = self.size - def iterator = new ju.Iterator[ju.Map.Entry[A, B]] { + override def iterator: ju.Iterator[ju.Map.Entry[A, B]] = new ju.Iterator[ju.Map.Entry[A, B]] { val ui = underlying.iterator var prev : Option[A] = None - def hasNext = ui.hasNext + def hasNext: Boolean = ui.hasNext - def next() = { - val (k, v) = ui.next + def next(): Entry[A, B] = { + val (k, v) = ui.next() prev = Some(k) new ju.Map.Entry[A, B] { import scala.util.hashing.byteswap32 - def getKey = k - def getValue = v - def setValue(v1 : B) = self.put(k, v1) - override def hashCode = byteswap32(k.hashCode) + (byteswap32(v.hashCode) << 16) - override def equals(other: Any) = other match { + override def getKey: A = k + override def getValue: B = v + override def setValue(v1 : B): B = self.put(k, v1) + override def hashCode: Int = byteswap32(k.hashCode) + (byteswap32(v.hashCode) << 16) + override def equals(other: Any): Boolean = other match { case e: ju.Map.Entry[_, _] => k == e.getKey && v == e.getValue case _ => false } @@ -81,7 +80,7 @@ private[spark] object JavaUtils { case Some(k) => underlying match { case mm: mutable.Map[A, _] => - mm remove k + mm.remove(k) prev = None case _ => throw new UnsupportedOperationException("remove") diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 4c71b69069eb3..19f4c95fcad74 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -54,9 +54,11 @@ private[spark] class PythonRDD( val bufferSize = conf.getInt("spark.buffer.size", 65536) val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) - override def getPartitions = firstParent.partitions + override def getPartitions: Array[Partition] = firstParent.partitions - override val partitioner = if (preservePartitoning) firstParent.partitioner else None + override val partitioner: Option[Partitioner] = { + if (preservePartitoning) firstParent.partitioner else None + } override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis @@ -92,7 +94,7 @@ private[spark] class PythonRDD( // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) val stdoutIterator = new Iterator[Array[Byte]] { - def next(): Array[Byte] = { + override def next(): Array[Byte] = { val obj = _nextObj if (hasNext) { _nextObj = read() @@ -175,7 +177,7 @@ private[spark] class PythonRDD( var _nextObj = read() - def hasNext = _nextObj != null + override def hasNext: Boolean = _nextObj != null } new InterruptibleIterator(context, stdoutIterator) } @@ -303,11 +305,10 @@ private class PythonException(msg: String, cause: Exception) extends RuntimeExce * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. * This is used by PySpark's shuffle operations. */ -private class PairwiseRDD(prev: RDD[Array[Byte]]) extends - RDD[(Long, Array[Byte])](prev) { - override def getPartitions = prev.partitions - override val partitioner = prev.partitioner - override def compute(split: Partition, context: TaskContext) = +private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte])](prev) { + override def getPartitions: Array[Partition] = prev.partitions + override val partitioner: Option[Partitioner] = prev.partitioner + override def compute(split: Partition, context: TaskContext): Iterator[(Long, Array[Byte])] = prev.iterator(split, context).grouped(2).map { case Seq(a, b) => (Utils.deserializeLongValue(a), b) case x => throw new SparkException("PairwiseRDD: unexpected value: " + x) @@ -435,7 +436,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, minSplits: Int, - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] @@ -462,7 +463,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration()) val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, @@ -488,7 +489,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, @@ -505,7 +506,7 @@ private[spark] object PythonRDD extends Logging { inputFormatClass: String, keyClass: String, valueClass: String, - conf: Configuration) = { + conf: Configuration): RDD[(K, V)] = { val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]] @@ -531,7 +532,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration()) val rdd = hadoopRDDFromClassNames[K, V, F](sc, @@ -557,7 +558,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - batchSize: Int) = { + batchSize: Int): JavaRDD[Array[Byte]] = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val rdd = hadoopRDDFromClassNames[K, V, F](sc, @@ -686,7 +687,7 @@ private[spark] object PythonRDD extends Logging { pyRDD: JavaRDD[Array[Byte]], batchSerialized: Boolean, path: String, - compressionCodecClass: String) = { + compressionCodecClass: String): Unit = { saveAsHadoopFile( pyRDD, batchSerialized, path, "org.apache.hadoop.mapred.SequenceFileOutputFormat", null, null, null, null, new java.util.HashMap(), compressionCodecClass) @@ -711,7 +712,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, confAsMap: java.util.HashMap[String, String], - compressionCodecClass: String) = { + compressionCodecClass: String): Unit = { val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized) val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) @@ -741,7 +742,7 @@ private[spark] object PythonRDD extends Logging { valueClass: String, keyConverterClass: String, valueConverterClass: String, - confAsMap: java.util.HashMap[String, String]) = { + confAsMap: java.util.HashMap[String, String]): Unit = { val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized) val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) @@ -766,7 +767,7 @@ private[spark] object PythonRDD extends Logging { confAsMap: java.util.HashMap[String, String], keyConverterClass: String, valueConverterClass: String, - useNewAPI: Boolean) = { + useNewAPI: Boolean): Unit = { val conf = PythonHadoopUtil.mapToConf(confAsMap) val converted = convertRDD(SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized), keyConverterClass, valueConverterClass, new JavaToWritableConverter) diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index fb52a960e0765..257491e90dd66 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -84,7 +84,7 @@ private[spark] object SerDeUtil extends Logging { private var initialized = false // This should be called before trying to unpickle array.array from Python // In cluster mode, this should be put in closure - def initialize() = { + def initialize(): Unit = { synchronized{ if (!initialized) { Unpickler.registerConstructor("array", "array", new ArrayConstructor()) diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index cf289fb3ae39f..8f30ff9202c83 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -18,38 +18,37 @@ package org.apache.spark.api.python import java.io.{DataOutput, DataInput} +import java.{util => ju} import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.io._ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat + +import org.apache.spark.SparkException import org.apache.spark.api.java.JavaSparkContext -import org.apache.spark.{SparkContext, SparkException} /** * A class to test Pyrolite serialization on the Scala side, that will be deserialized * in Python - * @param str - * @param int - * @param double */ case class TestWritable(var str: String, var int: Int, var double: Double) extends Writable { def this() = this("", 0, 0.0) - def getStr = str + def getStr: String = str def setStr(str: String) { this.str = str } - def getInt = int + def getInt: Int = int def setInt(int: Int) { this.int = int } - def getDouble = double + def getDouble: Double = double def setDouble(double: Double) { this.double = double } - def write(out: DataOutput) = { + def write(out: DataOutput): Unit = { out.writeUTF(str) out.writeInt(int) out.writeDouble(double) } - def readFields(in: DataInput) = { + def readFields(in: DataInput): Unit = { str = in.readUTF() int = in.readInt() double = in.readDouble() @@ -57,28 +56,28 @@ case class TestWritable(var str: String, var int: Int, var double: Double) exten } private[python] class TestInputKeyConverter extends Converter[Any, Any] { - override def convert(obj: Any) = { + override def convert(obj: Any): Char = { obj.asInstanceOf[IntWritable].get().toChar } } private[python] class TestInputValueConverter extends Converter[Any, Any] { import collection.JavaConversions._ - override def convert(obj: Any) = { + override def convert(obj: Any): ju.List[Double] = { val m = obj.asInstanceOf[MapWritable] seqAsJavaList(m.keySet.map(w => w.asInstanceOf[DoubleWritable].get()).toSeq) } } private[python] class TestOutputKeyConverter extends Converter[Any, Any] { - override def convert(obj: Any) = { + override def convert(obj: Any): Text = { new Text(obj.asInstanceOf[Int].toString) } } private[python] class TestOutputValueConverter extends Converter[Any, Any] { import collection.JavaConversions._ - override def convert(obj: Any) = { + override def convert(obj: Any): DoubleWritable = { new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().head) } } @@ -86,7 +85,7 @@ private[python] class TestOutputValueConverter extends Converter[Any, Any] { private[python] class DoubleArrayWritable extends ArrayWritable(classOf[DoubleWritable]) private[python] class DoubleArrayToWritableConverter extends Converter[Any, Writable] { - override def convert(obj: Any) = obj match { + override def convert(obj: Any): DoubleArrayWritable = obj match { case arr if arr.getClass.isArray && arr.getClass.getComponentType == classOf[Double] => val daw = new DoubleArrayWritable daw.set(arr.asInstanceOf[Array[Double]].map(new DoubleWritable(_))) diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index a5ea478f231d7..12d79f6ed311b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -146,5 +146,5 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Lo } } - override def toString = "Broadcast(" + id + ")" + override def toString: String = "Broadcast(" + id + ")" } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 8f8a0b11f9f2e..685313ac009ba 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -58,7 +58,7 @@ private[spark] class BroadcastManager( private val nextBroadcastId = new AtomicLong(0) - def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean) = { + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = { broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 1444c0dd3d2d6..74ccfa6d3c9a3 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -160,7 +160,7 @@ private[broadcast] object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } - def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name) + def getFile(id: Long): File = new File(broadcastDir, BroadcastBlockId(id).name) private def write(id: Long, value: Any) { val file = getFile(id) @@ -222,7 +222,7 @@ private[broadcast] object HttpBroadcast extends Logging { * If removeFromDriver is true, also remove these persisted blocks on the driver * and delete the associated broadcast file. */ - def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized { + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = synchronized { SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) if (removeFromDriver) { val file = getFile(id) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala index c7ef02d572a19..cf3ae36f27949 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -31,7 +31,7 @@ class HttpBroadcastFactory extends BroadcastFactory { HttpBroadcast.initialize(isDriver, conf, securityMgr) } - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = new HttpBroadcast[T](value_, isLocal, id) override def stop() { HttpBroadcast.stop() } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 94142d33369c7..23b02e60338fb 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -222,7 +222,7 @@ private object TorrentBroadcast extends Logging { * Remove all persisted blocks associated with this torrent broadcast on the executors. * If removeFromDriver is true, also remove these persisted blocks on the driver. */ - def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = { + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = { logDebug(s"Unpersisting TorrentBroadcast $id") SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index fb024c12094f2..96d8dd79908c8 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -30,7 +30,7 @@ class TorrentBroadcastFactory extends BroadcastFactory { override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { } - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = { + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = { new TorrentBroadcast[T](value_, id) } diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 237d26fc6bd0e..65238af2caa24 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -38,7 +38,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) var masterActor: ActorSelection = _ val timeout = AkkaUtils.askTimeout(conf) - override def preStart() = { + override def preStart(): Unit = { masterActor = context.actorSelection( Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(context.system))) @@ -118,7 +118,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) } } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case SubmitDriverResponse(success, driverId, message) => println(message) diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 53bc62aff7395..5cbac787dceeb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -42,7 +42,7 @@ private[deploy] class ClientArguments(args: Array[String]) { var memory: Int = DEFAULT_MEMORY var cores: Int = DEFAULT_CORES private var _driverOptions = ListBuffer[String]() - def driverOptions = _driverOptions.toSeq + def driverOptions: Seq[String] = _driverOptions.toSeq // kill parameters var driverId: String = "" diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 7f600d89604a2..9db6fd1ac4dbe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -101,6 +101,8 @@ private[deploy] object DeployMessages { case class RegisterApplication(appDescription: ApplicationDescription) extends DeployMessage + case class UnregisterApplication(appId: String) + case class MasterChangeAcknowledged(appId: String) // Master to AppClient @@ -162,7 +164,7 @@ private[deploy] object DeployMessages { Utils.checkHost(host, "Required hostname") assert (port > 0) - def uri = "spark://" + host + ":" + port + def uri: String = "spark://" + host + ":" + port def restUri: Option[String] = restPort.map { p => "spark://" + host + ":" + p } } diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 5668b53fc6f4f..a7c89276a045e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -426,7 +426,7 @@ private object SparkDocker { } private class DockerId(val id: String) { - override def toString = id + override def toString: String = id } private object Docker extends Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 458a7c3a455de..dfc5b97e6a6c8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy +import org.json4s.JsonAST.JObject import org.json4s.JsonDSL._ import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} @@ -24,7 +25,7 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.worker.ExecutorRunner private[deploy] object JsonProtocol { - def writeWorkerInfo(obj: WorkerInfo) = { + def writeWorkerInfo(obj: WorkerInfo): JObject = { ("id" -> obj.id) ~ ("host" -> obj.host) ~ ("port" -> obj.port) ~ @@ -39,7 +40,7 @@ private[deploy] object JsonProtocol { ("lastheartbeat" -> obj.lastHeartbeat) } - def writeApplicationInfo(obj: ApplicationInfo) = { + def writeApplicationInfo(obj: ApplicationInfo): JObject = { ("starttime" -> obj.startTime) ~ ("id" -> obj.id) ~ ("name" -> obj.desc.name) ~ @@ -51,7 +52,7 @@ private[deploy] object JsonProtocol { ("duration" -> obj.duration) } - def writeApplicationDescription(obj: ApplicationDescription) = { + def writeApplicationDescription(obj: ApplicationDescription): JObject = { ("name" -> obj.name) ~ ("cores" -> obj.maxCores) ~ ("memoryperslave" -> obj.memoryPerSlave) ~ @@ -59,14 +60,14 @@ private[deploy] object JsonProtocol { ("command" -> obj.command.toString) } - def writeExecutorRunner(obj: ExecutorRunner) = { + def writeExecutorRunner(obj: ExecutorRunner): JObject = { ("id" -> obj.execId) ~ ("memory" -> obj.memory) ~ ("appid" -> obj.appId) ~ ("appdesc" -> writeApplicationDescription(obj.appDesc)) } - def writeDriverInfo(obj: DriverInfo) = { + def writeDriverInfo(obj: DriverInfo): JObject = { ("id" -> obj.id) ~ ("starttime" -> obj.startTime.toString) ~ ("state" -> obj.state.toString) ~ @@ -74,7 +75,7 @@ private[deploy] object JsonProtocol { ("memory" -> obj.desc.mem) } - def writeMasterState(obj: MasterStateResponse) = { + def writeMasterState(obj: MasterStateResponse): JObject = { ("url" -> obj.uri) ~ ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ ("cores" -> obj.workers.map(_.cores).sum) ~ @@ -87,7 +88,7 @@ private[deploy] object JsonProtocol { ("status" -> obj.status.toString) } - def writeWorkerState(obj: WorkerStateResponse) = { + def writeWorkerState(obj: WorkerStateResponse): JObject = { ("id" -> obj.workerId) ~ ("masterurl" -> obj.masterUrl) ~ ("masterwebuiurl" -> obj.masterWebUiUrl) ~ diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index e0a32fb65cd51..c2568eb4b60ac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -193,7 +193,7 @@ class SparkHadoopUtil extends Logging { * that file. */ def listLeafStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = { - def recurse(path: Path) = { + def recurse(path: Path): Array[FileStatus] = { val (directories, leaves) = fs.listStatus(path).partition(_.isDir) leaves ++ directories.flatMap(f => listLeafStatuses(fs, f.getPath)) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 4f506be63fe59..660307d19eab4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -777,7 +777,7 @@ private[deploy] object SparkSubmitUtils { } /** A nice function to use in tests as well. Values are dummy strings. */ - def getModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( + def getModuleDescriptor: DefaultModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-parent", "1.0")) /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 2250d5a28e4ef..6eb73c43470a5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -252,7 +252,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S master.startsWith("spark://") && deployMode == "cluster" } - override def toString = { + override def toString: String = { s"""Parsed arguments: | master $master | deployMode $deployMode diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 2d24083a77b73..4f06d7f96c46e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -116,7 +116,7 @@ private[spark] class AppClient( masterAkkaUrls.map(AddressFromURIString(_).hostPort).contains(remoteUrl.hostPort) } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisteredApplication(appId_, masterUrl) => appId = appId_ registered = true @@ -157,6 +157,7 @@ private[spark] class AppClient( case StopAppClient => markDead("Application has been stopped.") + master ! UnregisterApplication(appId) sender ! true context.stop(self) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index db7c499661319..80c9c13ddec1e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -93,7 +93,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis */ private def getRunner(operateFun: () => Unit): Runnable = { new Runnable() { - override def run() = Utils.tryOrExit { + override def run(): Unit = Utils.tryOrExit { operateFun() } } @@ -141,7 +141,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } } - override def getListing() = applications.values + override def getListing(): Iterable[FsApplicationHistoryInfo] = applications.values override def getAppUI(appId: String): Option[SparkUI] = { try { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index af483d560b33e..72f6048239297 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -61,7 +61,7 @@ class HistoryServer( private val appCache = CacheBuilder.newBuilder() .maximumSize(retainedApplications) .removalListener(new RemovalListener[String, SparkUI] { - override def onRemoval(rm: RemovalNotification[String, SparkUI]) = { + override def onRemoval(rm: RemovalNotification[String, SparkUI]): Unit = { detachSparkUI(rm.getValue()) } }) @@ -149,14 +149,14 @@ class HistoryServer( * * @return List of all known applications. */ - def getApplicationList() = provider.getListing() + def getApplicationList(): Iterable[ApplicationHistoryInfo] = provider.getListing() /** * Returns the provider configuration to show in the listing page. * * @return A map with the provider's configuration. */ - def getProviderConfig() = provider.getConfig() + def getProviderConfig(): Map[String, String] = provider.getConfig() } @@ -195,9 +195,7 @@ object HistoryServer extends Logging { server.bind() Runtime.getRuntime().addShutdownHook(new Thread("HistoryServerStopper") { - override def run() = { - server.stop() - } + override def run(): Unit = server.stop() }) // Wait until the end of the world... or if the HistoryServer process is manually stopped diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 536aedb6f9fe9..bc5b293379f2b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -91,7 +91,7 @@ private[deploy] class ApplicationInfo( } } - private[master] val requestedCores = desc.maxCores.getOrElse(defaultCores) + private val requestedCores = desc.maxCores.getOrElse(defaultCores) private[master] def coresLeft: Int = requestedCores - coresGranted @@ -111,6 +111,10 @@ private[deploy] class ApplicationInfo( endTime = System.currentTimeMillis() } + private[master] def isFinished: Boolean = { + state != ApplicationState.WAITING && state != ApplicationState.RUNNING + } + def duration: Long = { if (endTime != -1) { endTime - startTime diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index d2d30bfd7fcba..32499b3a784a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -48,7 +48,7 @@ private[master] class FileSystemPersistenceEngine( new File(dir + File.separator + name).delete() } - override def read[T: ClassTag](prefix: String) = { + override def read[T: ClassTag](prefix: String): Seq[T] = { val files = new File(dir).listFiles().filter(_.getName.startsWith(prefix)) files.map(deserializeFromFile[T]) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 1b42121c8db05..9a5d5877da86d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -204,7 +204,7 @@ private[master] class Master( self ! RevokedLeadership } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case ElectedLeader => { val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { @@ -339,7 +339,11 @@ private[master] class Master( if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") - appInfo.removeExecutor(exec) + // If an application has already finished, preserve its + // state to display its information properly on the UI + if (!appInfo.isFinished) { + appInfo.removeExecutor(exec) + } exec.worker.removeExecutor(exec) val normalExit = exitStatus == Some(0) @@ -428,6 +432,10 @@ private[master] class Master( if (canCompleteRecovery) { completeRecovery() } } + case UnregisterApplication(applicationId) => + logInfo(s"Received unregister request from application $applicationId") + idToApp.get(applicationId).foreach(finishApplication) + case DisassociatedEvent(_, address, _) => { // The disconnected client could've been either a worker or an app; remove whichever it was logInfo(s"$address got disassociated, removing it.") diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala index 1583bf1f60032..351db8fab2041 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -51,20 +51,27 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serial */ private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization) extends StandaloneRecoveryModeFactory(conf, serializer) with Logging { + val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") - def createPersistenceEngine() = { + def createPersistenceEngine(): PersistenceEngine = { logInfo("Persisting recovery state to directory: " + RECOVERY_DIR) new FileSystemPersistenceEngine(RECOVERY_DIR, serializer) } - def createLeaderElectionAgent(master: LeaderElectable) = new MonarchyLeaderAgent(master) + def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent = { + new MonarchyLeaderAgent(master) + } } private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization) extends StandaloneRecoveryModeFactory(conf, serializer) { - def createPersistenceEngine() = new ZooKeeperPersistenceEngine(conf, serializer) - def createLeaderElectionAgent(master: LeaderElectable) = + def createPersistenceEngine(): PersistenceEngine = { + new ZooKeeperPersistenceEngine(conf, serializer) + } + + def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent = { new ZooKeeperLeaderElectionAgent(master, conf) + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index e94aae93e4495..9b3d48c6edc84 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -104,7 +104,7 @@ private[spark] class WorkerInfo( "http://" + this.publicAddress + ":" + this.webUiPort } - def setState(state: WorkerState.Value) = { + def setState(state: WorkerState.Value): Unit = { this.state = state } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 1ac6677ad2b6d..a285783f72000 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -46,7 +46,7 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializat zk.delete().forPath(WORKING_DIR + "/" + name) } - override def read[T: ClassTag](prefix: String) = { + override def read[T: ClassTag](prefix: String): Seq[T] = { val file = zk.getChildren.forPath(WORKING_DIR).filter(_.startsWith(prefix)) file.map(deserializeFromFile[T]).flatten } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index dee2e4a447c6e..45412a35e9a7d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -75,16 +75,12 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { val workers = state.workers.sortBy(_.id) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) - val activeAppHeaders = Seq("Application ID", "Name", "Cores in Use", - "Cores Requested", "Memory per Node", "Submitted Time", "User", "State", "Duration") + val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time", + "User", "State", "Duration") val activeApps = state.activeApps.sortBy(_.startTime).reverse - val activeAppsTable = UIUtils.listingTable(activeAppHeaders, activeAppRow, activeApps) - - val completedAppHeaders = Seq("Application ID", "Name", "Cores Requested", "Memory per Node", - "Submitted Time", "User", "State", "Duration") + val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps) val completedApps = state.completedApps.sortBy(_.endTime).reverse - val completedAppsTable = UIUtils.listingTable(completedAppHeaders, completeAppRow, - completedApps) + val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps) val driverHeaders = Seq("Submission ID", "Submitted Time", "Worker", "State", "Cores", "Memory", "Main Class") @@ -95,7 +91,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { // For now we only show driver information if the user has submitted drivers to the cluster. // This is until we integrate the notion of drivers and applications in the UI. - def hasDrivers = activeDrivers.length > 0 || completedDrivers.length > 0 + def hasDrivers: Boolean = activeDrivers.length > 0 || completedDrivers.length > 0 val content =
@@ -191,7 +187,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } - private def appRow(app: ApplicationInfo, active: Boolean): Seq[Node] = { + private def appRow(app: ApplicationInfo): Seq[Node] = { val killLink = if (parent.killEnabled && (app.state == ApplicationState.RUNNING || app.state == ApplicationState.WAITING)) { val killLinkUri = s"app/kill?id=${app.id}&terminate=true" @@ -201,7 +197,6 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { (kill) } - {app.id} @@ -210,15 +205,8 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {app.desc.name} - { - if (active) { - - {app.coresGranted} - - } - } - {if (app.requestedCores == Int.MaxValue) "*" else app.requestedCores} + {app.coresGranted} {Utils.megabytesToString(app.desc.memoryPerSlave)} @@ -230,14 +218,6 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } - private def activeAppRow(app: ApplicationInfo): Seq[Node] = { - appRow(app, active = true) - } - - private def completeAppRow(app: ApplicationInfo): Seq[Node] = { - appRow(app, active = false) - } - private def driverRow(driver: DriverInfo): Seq[Node] = { val killLink = if (parent.killEnabled && (driver.state == DriverState.RUNNING || diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 27a9eabb1ede7..e0948e16ef354 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -56,8 +56,14 @@ private[deploy] class DriverRunner( private var finalExitCode: Option[Int] = None // Decoupled for testing - def setClock(_clock: Clock) = clock = _clock - def setSleeper(_sleeper: Sleeper) = sleeper = _sleeper + def setClock(_clock: Clock): Unit = { + clock = _clock + } + + def setSleeper(_sleeper: Sleeper): Unit = { + sleeper = _sleeper + } + private var clock: Clock = new SystemClock() private var sleeper = new Sleeper { def sleep(seconds: Int): Unit = (0 until seconds).takeWhile(f => {Thread.sleep(1000); !killed}) @@ -155,7 +161,7 @@ private[deploy] class DriverRunner( private def launchDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean) { builder.directory(baseDir) - def initialize(process: Process) = { + def initialize(process: Process): Unit = { // Redirect stdout and stderr to files val stdout = new File(baseDir, "stdout") CommandUtils.redirectStream(process.getInputStream, stdout) @@ -169,8 +175,8 @@ private[deploy] class DriverRunner( runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise) } - def runCommandWithRetry(command: ProcessBuilderLike, initialize: Process => Unit, - supervise: Boolean) { + def runCommandWithRetry( + command: ProcessBuilderLike, initialize: Process => Unit, supervise: Boolean): Unit = { // Time to wait between submission retries. var waitSeconds = 1 // A run of this many seconds resets the exponential back-off. @@ -216,8 +222,8 @@ private[deploy] trait ProcessBuilderLike { } private[deploy] object ProcessBuilderLike { - def apply(processBuilder: ProcessBuilder) = new ProcessBuilderLike { - def start() = processBuilder.start() - def command = processBuilder.command() + def apply(processBuilder: ProcessBuilder): ProcessBuilderLike = new ProcessBuilderLike { + override def start(): Process = processBuilder.start() + override def command: Seq[String] = processBuilder.command() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index c1b0a295f9f74..c4c24a7866aa3 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -275,7 +275,7 @@ private[worker] class Worker( } } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisteredWorker(masterUrl, masterWebUiUrl) => logInfo("Successfully registered with master " + masterUrl) registered = true diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 09d866fb0cd90..e0790274d7d3e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -50,7 +50,7 @@ private[spark] class WorkerWatcher(workerUrl: String) private def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index dd19e4947db1e..b5205d4e997ae 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -62,7 +62,7 @@ private[spark] class CoarseGrainedExecutorBackend( .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) } - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala index 41925f7e97e84..3e47d13f7545d 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala @@ -33,7 +33,7 @@ private[spark] case object TriggerThreadDump private[spark] class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case TriggerThreadDump => sender ! Utils.getThreadDump() } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 07b152651dedf..06152f16ae618 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,13 +17,10 @@ package org.apache.spark.executor -import java.util.concurrent.atomic.AtomicLong - -import org.apache.spark.executor.DataReadMethod.DataReadMethod - import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.DataReadMethod.DataReadMethod import org.apache.spark.storage.{BlockId, BlockStatus} /** @@ -44,14 +41,14 @@ class TaskMetrics extends Serializable { * Host's name the task runs on */ private var _hostname: String = _ - def hostname = _hostname + def hostname: String = _hostname private[spark] def setHostname(value: String) = _hostname = value /** * Time taken on the executor to deserialize this task */ private var _executorDeserializeTime: Long = _ - def executorDeserializeTime = _executorDeserializeTime + def executorDeserializeTime: Long = _executorDeserializeTime private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value @@ -59,14 +56,14 @@ class TaskMetrics extends Serializable { * Time the executor spends actually running the task (including fetching shuffle data) */ private var _executorRunTime: Long = _ - def executorRunTime = _executorRunTime + def executorRunTime: Long = _executorRunTime private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value /** * The number of bytes this task transmitted back to the driver as the TaskResult */ private var _resultSize: Long = _ - def resultSize = _resultSize + def resultSize: Long = _resultSize private[spark] def setResultSize(value: Long) = _resultSize = value @@ -74,31 +71,31 @@ class TaskMetrics extends Serializable { * Amount of time the JVM spent in garbage collection while executing this task */ private var _jvmGCTime: Long = _ - def jvmGCTime = _jvmGCTime + def jvmGCTime: Long = _jvmGCTime private[spark] def setJvmGCTime(value: Long) = _jvmGCTime = value /** * Amount of time spent serializing the task result */ private var _resultSerializationTime: Long = _ - def resultSerializationTime = _resultSerializationTime + def resultSerializationTime: Long = _resultSerializationTime private[spark] def setResultSerializationTime(value: Long) = _resultSerializationTime = value /** * The number of in-memory bytes spilled by this task */ private var _memoryBytesSpilled: Long = _ - def memoryBytesSpilled = _memoryBytesSpilled - private[spark] def incMemoryBytesSpilled(value: Long) = _memoryBytesSpilled += value - private[spark] def decMemoryBytesSpilled(value: Long) = _memoryBytesSpilled -= value + def memoryBytesSpilled: Long = _memoryBytesSpilled + private[spark] def incMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled += value + private[spark] def decMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled -= value /** * The number of on-disk bytes spilled by this task */ private var _diskBytesSpilled: Long = _ - def diskBytesSpilled = _diskBytesSpilled - def incDiskBytesSpilled(value: Long) = _diskBytesSpilled += value - def decDiskBytesSpilled(value: Long) = _diskBytesSpilled -= value + def diskBytesSpilled: Long = _diskBytesSpilled + def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value + def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value /** * If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read @@ -106,7 +103,7 @@ class TaskMetrics extends Serializable { */ private var _inputMetrics: Option[InputMetrics] = None - def inputMetrics = _inputMetrics + def inputMetrics: Option[InputMetrics] = _inputMetrics /** * This should only be used when recreating TaskMetrics, not when updating input metrics in @@ -128,7 +125,7 @@ class TaskMetrics extends Serializable { */ private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None - def shuffleReadMetrics = _shuffleReadMetrics + def shuffleReadMetrics: Option[ShuffleReadMetrics] = _shuffleReadMetrics /** * This should only be used when recreating TaskMetrics, not when updating read metrics in @@ -177,17 +174,18 @@ class TaskMetrics extends Serializable { * Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed, * we can store all the different inputMetrics (one per readMethod). */ - private[spark] def getInputMetricsForReadMethod( - readMethod: DataReadMethod): InputMetrics = synchronized { - _inputMetrics match { - case None => - val metrics = new InputMetrics(readMethod) - _inputMetrics = Some(metrics) - metrics - case Some(metrics @ InputMetrics(method)) if method == readMethod => - metrics - case Some(InputMetrics(method)) => - new InputMetrics(readMethod) + private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod): InputMetrics = { + synchronized { + _inputMetrics match { + case None => + val metrics = new InputMetrics(readMethod) + _inputMetrics = Some(metrics) + metrics + case Some(metrics @ InputMetrics(method)) if method == readMethod => + metrics + case Some(InputMetrics(method)) => + new InputMetrics(readMethod) + } } } @@ -256,14 +254,14 @@ case class InputMetrics(readMethod: DataReadMethod.Value) { */ private var _bytesRead: Long = _ def bytesRead: Long = _bytesRead - def incBytesRead(bytes: Long) = _bytesRead += bytes + def incBytesRead(bytes: Long): Unit = _bytesRead += bytes /** * Total records read. */ private var _recordsRead: Long = _ def recordsRead: Long = _recordsRead - def incRecordsRead(records: Long) = _recordsRead += records + def incRecordsRead(records: Long): Unit = _recordsRead += records /** * Invoke the bytesReadCallback and mutate bytesRead. @@ -293,15 +291,15 @@ case class OutputMetrics(writeMethod: DataWriteMethod.Value) { * Total bytes written */ private var _bytesWritten: Long = _ - def bytesWritten = _bytesWritten - private[spark] def setBytesWritten(value : Long) = _bytesWritten = value + def bytesWritten: Long = _bytesWritten + private[spark] def setBytesWritten(value : Long): Unit = _bytesWritten = value /** * Total records written */ private var _recordsWritten: Long = 0L - def recordsWritten = _recordsWritten - private[spark] def setRecordsWritten(value: Long) = _recordsWritten = value + def recordsWritten: Long = _recordsWritten + private[spark] def setRecordsWritten(value: Long): Unit = _recordsWritten = value } /** @@ -314,7 +312,7 @@ class ShuffleReadMetrics extends Serializable { * Number of remote blocks fetched in this shuffle by this task */ private var _remoteBlocksFetched: Int = _ - def remoteBlocksFetched = _remoteBlocksFetched + def remoteBlocksFetched: Int = _remoteBlocksFetched private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value @@ -322,7 +320,7 @@ class ShuffleReadMetrics extends Serializable { * Number of local blocks fetched in this shuffle by this task */ private var _localBlocksFetched: Int = _ - def localBlocksFetched = _localBlocksFetched + def localBlocksFetched: Int = _localBlocksFetched private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value @@ -332,7 +330,7 @@ class ShuffleReadMetrics extends Serializable { * still not finished processing block A, it is not considered to be blocking on block B. */ private var _fetchWaitTime: Long = _ - def fetchWaitTime = _fetchWaitTime + def fetchWaitTime: Long = _fetchWaitTime private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value @@ -340,7 +338,7 @@ class ShuffleReadMetrics extends Serializable { * Total number of remote bytes read from the shuffle by this task */ private var _remoteBytesRead: Long = _ - def remoteBytesRead = _remoteBytesRead + def remoteBytesRead: Long = _remoteBytesRead private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value @@ -348,24 +346,24 @@ class ShuffleReadMetrics extends Serializable { * Shuffle data that was read from the local disk (as opposed to from a remote executor). */ private var _localBytesRead: Long = _ - def localBytesRead = _localBytesRead + def localBytesRead: Long = _localBytesRead private[spark] def incLocalBytesRead(value: Long) = _localBytesRead += value /** * Total bytes fetched in the shuffle by this task (both remote and local). */ - def totalBytesRead = _remoteBytesRead + _localBytesRead + def totalBytesRead: Long = _remoteBytesRead + _localBytesRead /** * Number of blocks fetched in this shuffle by this task (remote or local) */ - def totalBlocksFetched = _remoteBlocksFetched + _localBlocksFetched + def totalBlocksFetched: Int = _remoteBlocksFetched + _localBlocksFetched /** * Total number of records read from the shuffle by this task */ private var _recordsRead: Long = _ - def recordsRead = _recordsRead + def recordsRead: Long = _recordsRead private[spark] def incRecordsRead(value: Long) = _recordsRead += value private[spark] def decRecordsRead(value: Long) = _recordsRead -= value } @@ -380,7 +378,7 @@ class ShuffleWriteMetrics extends Serializable { * Number of bytes written for the shuffle by this task */ @volatile private var _shuffleBytesWritten: Long = _ - def shuffleBytesWritten = _shuffleBytesWritten + def shuffleBytesWritten: Long = _shuffleBytesWritten private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value @@ -388,7 +386,7 @@ class ShuffleWriteMetrics extends Serializable { * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ @volatile private var _shuffleWriteTime: Long = _ - def shuffleWriteTime= _shuffleWriteTime + def shuffleWriteTime: Long = _shuffleWriteTime private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value @@ -396,7 +394,7 @@ class ShuffleWriteMetrics extends Serializable { * Total number of records written to the shuffle by this task */ @volatile private var _shuffleRecordsWritten: Long = _ - def shuffleRecordsWritten = _shuffleRecordsWritten + def shuffleRecordsWritten: Long = _shuffleRecordsWritten private[spark] def incShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten += value private[spark] def decShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten -= value private[spark] def setShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten = value diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 593a62b3e3b32..6cda7772f77bc 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -73,16 +73,16 @@ private[spark] abstract class StreamBasedRecordReader[T]( private var key = "" private var value: T = null.asInstanceOf[T] - override def initialize(split: InputSplit, context: TaskAttemptContext) = {} - override def close() = {} + override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = {} + override def close(): Unit = {} - override def getProgress = if (processed) 1.0f else 0.0f + override def getProgress: Float = if (processed) 1.0f else 0.0f - override def getCurrentKey = key + override def getCurrentKey: String = key - override def getCurrentValue = value + override def getCurrentValue: T = value - override def nextKeyValue = { + override def nextKeyValue: Boolean = { if (!processed) { val fileIn = new PortableDataStream(split, context, index) value = parseStream(fileIn) @@ -119,7 +119,8 @@ private[spark] class StreamRecordReader( * The format for the PortableDataStream files */ private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDataStream] { - override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) = { + override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) + : CombineFileRecordReader[String, PortableDataStream] = { new CombineFileRecordReader[String, PortableDataStream]( split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader]) } @@ -204,7 +205,7 @@ class PortableDataStream( /** * Close the file (if it is currently open) */ - def close() = { + def close(): Unit = { if (isOpen) { try { fileIn.close() diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 21b782edd2a9e..87c2aa481095d 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -52,7 +52,7 @@ trait SparkHadoopMapRedUtil { jobId: Int, isMap: Boolean, taskId: Int, - attemptId: Int) = { + attemptId: Int): TaskAttemptID = { new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId) } diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index 3340673f91156..cfd20392d12f1 100644 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -45,7 +45,7 @@ trait SparkHadoopMapReduceUtil { jobId: Int, isMap: Boolean, taskId: Int, - attemptId: Int) = { + attemptId: Int): TaskAttemptID = { val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID") try { // First, attempt to use the old-style constructor that takes a boolean isMap diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 345db36630fd5..9150ad35712a1 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit import scala.collection.mutable import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.metrics.sink.{MetricsServlet, Sink} @@ -84,7 +85,7 @@ private[spark] class MetricsSystem private ( /** * Get any UI handlers used by this metrics system; can only be called after start(). */ - def getServletHandlers = { + def getServletHandlers: Array[ServletContextHandler] = { require(running, "Can only call getServletHandlers on a running MetricsSystem") metricsServlet.map(_.getHandlers).getOrElse(Array()) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 2f65bc8b46609..0c2e212a33074 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -30,8 +30,12 @@ import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.SecurityManager import org.apache.spark.ui.JettyUtils._ -private[spark] class MetricsServlet(val property: Properties, val registry: MetricRegistry, - securityMgr: SecurityManager) extends Sink { +private[spark] class MetricsServlet( + val property: Properties, + val registry: MetricRegistry, + securityMgr: SecurityManager) + extends Sink { + val SERVLET_KEY_PATH = "path" val SERVLET_KEY_SAMPLE = "sample" @@ -45,10 +49,12 @@ private[spark] class MetricsServlet(val property: Properties, val registry: Metr val mapper = new ObjectMapper().registerModule( new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - def getHandlers = Array[ServletContextHandler]( - createServletHandler(servletPath, - new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) - ) + def getHandlers: Array[ServletContextHandler] = { + Array[ServletContextHandler]( + createServletHandler(servletPath, + new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) + ) + } def getMetricsSnapshot(request: HttpServletRequest): String = { mapper.writeValueAsString(registry) diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala index 0d83d8c425ca4..9fad4e7deacb6 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala @@ -18,7 +18,7 @@ package org.apache.spark.metrics.sink private[spark] trait Sink { - def start: Unit - def stop: Unit + def start(): Unit + def stop(): Unit def report(): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala index a1a2c00ed1542..1ba25aa74aa02 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala @@ -32,11 +32,11 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) def this() = this(null.asInstanceOf[Seq[BlockMessage]]) - def apply(i: Int) = blockMessages(i) + def apply(i: Int): BlockMessage = blockMessages(i) - def iterator = blockMessages.iterator + def iterator: Iterator[BlockMessage] = blockMessages.iterator - def length = blockMessages.length + def length: Int = blockMessages.length def set(bufferMessage: BufferMessage) { val startTime = System.currentTimeMillis diff --git a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala index 3b245c5c7a4f3..9a9e22b0c2366 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala @@ -31,9 +31,9 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val initialSize = currentSize() var gotChunkForSendingOnce = false - def size = initialSize + def size: Int = initialSize - def currentSize() = { + def currentSize(): Int = { if (buffers == null || buffers.isEmpty) { 0 } else { @@ -100,11 +100,11 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: buffers.foreach(_.flip) } - def hasAckId() = (ackId != 0) + def hasAckId(): Boolean = ackId != 0 - def isCompletelyReceived() = !buffers(0).hasRemaining + def isCompletelyReceived: Boolean = !buffers(0).hasRemaining - override def toString = { + override def toString: String = { if (hasAckId) { "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" } else { diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index c2d9578be7ebb..04eb2bf9ba4ab 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -101,9 +101,11 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, socketRemoteConnectionManagerId } - def key() = channel.keyFor(selector) + def key(): SelectionKey = channel.keyFor(selector) - def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] + def getRemoteAddress(): InetSocketAddress = { + channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] + } // Returns whether we have to register for further reads or not. def read(): Boolean = { @@ -280,7 +282,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, /* channel.socket.setSendBufferSize(256 * 1024) */ - override def getRemoteAddress() = address + override def getRemoteAddress(): InetSocketAddress = address val DEFAULT_INTEREST = SelectionKey.OP_READ diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala index 764dc5e5503ed..b3b281ff465f1 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala @@ -18,7 +18,9 @@ package org.apache.spark.network.nio private[nio] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { - override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId + override def toString: String = { + connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId + } } private[nio] object ConnectionId { diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index ee22c6656e69e..741fe3e1ea750 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -188,7 +188,7 @@ private[nio] class ConnectionManager( private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() private val selectorThread = new Thread("connection-manager-thread") { - override def run() = ConnectionManager.this.run() + override def run(): Unit = ConnectionManager.this.run() } selectorThread.setDaemon(true) // start this thread last, since it invokes run(), which accesses members above diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala index cbb37ec5ced1f..1cd13d887c6f6 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala @@ -26,7 +26,7 @@ private[nio] case class ConnectionManagerId(host: String, port: Int) { Utils.checkHost(host) assert (port > 0) - def toSocketAddress() = new InetSocketAddress(host, port) + def toSocketAddress(): InetSocketAddress = new InetSocketAddress(host, port) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala index fb4a979b824c3..85d2fe2bf9c20 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala @@ -42,7 +42,9 @@ private[nio] abstract class Message(val typ: Long, val id: Int) { def timeTaken(): String = (finishTime - startTime).toString + " ms" - override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" + override def toString: String = { + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" + } } @@ -51,7 +53,7 @@ private[nio] object Message { var lastId = 1 - def getNewId() = synchronized { + def getNewId(): Int = synchronized { lastId += 1 if (lastId == 0) { lastId += 1 diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala index 278c5ac356ef2..a4568e849fa13 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer private[nio] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - val size = if (buffer == null) 0 else buffer.remaining + val size: Int = if (buffer == null) 0 else buffer.remaining - lazy val buffers = { + lazy val buffers: ArrayBuffer[ByteBuffer] = { val ab = new ArrayBuffer[ByteBuffer]() ab += header.buffer if (buffer != null) { @@ -35,7 +35,7 @@ class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { ab } - override def toString = { + override def toString: String = { "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala index 6e20f291c5cec..7b3da4bb9d5ee 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala @@ -50,8 +50,10 @@ private[nio] class MessageChunkHeader( flip.asInstanceOf[ByteBuffer] } - override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + + override def toString: String = { + "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg + } } diff --git a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala index cadd0c7ed19ba..53c4b32c95ab3 100644 --- a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala +++ b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala @@ -99,7 +99,7 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) { case None => "(partial: " + initialValue + ")" } } - def getFinalValueInternal() = PartialResult.this.getFinalValueInternal().map(f) + def getFinalValueInternal(): Option[T] = PartialResult.this.getFinalValueInternal().map(f) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index 1cbd684224b7c..9059eb13bb5d8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -70,7 +70,7 @@ class CartesianRDD[T: ClassTag, U: ClassTag]( (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct } - override def compute(split: Partition, context: TaskContext) = { + override def compute(split: Partition, context: TaskContext): Iterator[(T, U)] = { val currSplit = split.asInstanceOf[CartesianPartition] for (x <- rdd1.iterator(currSplit.s1, context); y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index b073eba8a1574..5117ccfabfcc2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -186,7 +186,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: override val isEmpty = !it.hasNext // initializes/resets to start iterating from the beginning - def resetIterator() = { + def resetIterator(): Iterator[(String, Partition)] = { val iterators = (0 to 2).map( x => prev.partitions.iterator.flatMap(p => { if (currPrefLocs(p).size > x) Some((currPrefLocs(p)(x), p)) else None @@ -196,10 +196,10 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: } // hasNext() is false iff there are no preferredLocations for any of the partitions of the RDD - def hasNext(): Boolean = { !isEmpty } + override def hasNext: Boolean = { !isEmpty } // return the next preferredLocation of some partition of the RDD - def next(): (String, Partition) = { + override def next(): (String, Partition) = { if (it.hasNext) { it.next() } else { @@ -237,7 +237,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: val rotIt = new LocationIterator(prev) // deal with empty case, just create targetLen partition groups with no preferred location - if (!rotIt.hasNext()) { + if (!rotIt.hasNext) { (1 to targetLen).foreach(x => groupArr += PartitionGroup()) return } @@ -343,7 +343,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: private case class PartitionGroup(prefLoc: Option[String] = None) { var arr = mutable.ArrayBuffer[Partition]() - def size = arr.size + def size: Int = arr.size } private object PartitionGroup { diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 03afc289736bb..71e6e300fec5f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -191,25 +191,23 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { } } // Determine the bucket function in constant time. Requires that buckets are evenly spaced - def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = { + def fastBucketFunction(min: Double, max: Double, count: Int)(e: Double): Option[Int] = { // If our input is not a number unless the increment is also NaN then we fail fast - if (e.isNaN()) { - return None - } - val bucketNumber = (e - min)/(increment) - // We do this rather than buckets.lengthCompare(bucketNumber) - // because Array[Double] fails to override it (for now). - if (bucketNumber > count || bucketNumber < 0) { + if (e.isNaN || e < min || e > max) { None } else { - Some(bucketNumber.toInt.min(count - 1)) + // Compute ratio of e's distance along range to total range first, for better precision + val bucketNumber = (((e - min) / (max - min)) * count).toInt + // should be less than count, but will equal count if e == max, in which case + // it's part of the last end-range-inclusive bucket, so return count-1 + Some(math.min(bucketNumber, count - 1)) } } // Decide which bucket function to pass to histogramPartition. We decide here - // rather than having a general function so that the decission need only be made + // rather than having a general function so that the decision need only be made // once rather than once per shard val bucketFunction = if (evenBuckets) { - fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _ + fastBucketFunction(buckets.head, buckets.last, buckets.length - 1) _ } else { basicBucketFunction _ } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 486e86ce1bb19..f77abac42b623 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -215,8 +215,7 @@ class HadoopRDD[K, V]( logInfo("Input split: " + split.inputSplit) val jobConf = getJobConf() - val inputMetrics = context.taskMetrics - .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + val inputMetrics = context.taskMetrics.getInputMetricsForReadMethod(DataReadMethod.Hadoop) // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes @@ -240,7 +239,7 @@ class HadoopRDD[K, V]( val key: K = reader.createKey() val value: V = reader.createValue() - override def getNext() = { + override def getNext(): (K, V) = { try { finished = !reader.next(key, value) } catch { @@ -337,11 +336,11 @@ private[spark] object HadoopRDD extends Logging { * The three methods below are helpers for accessing the local map, a property of the SparkEnv of * the local process. */ - def getCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.get(key) + def getCachedMetadata(key: String): Any = SparkEnv.get.hadoopJobMetadata.get(key) - def containsCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.containsKey(key) + def containsCachedMetadata(key: String): Boolean = SparkEnv.get.hadoopJobMetadata.containsKey(key) - def putCachedMetadata(key: String, value: Any) = + private def putCachedMetadata(key: String, value: Any): Unit = SparkEnv.get.hadoopJobMetadata.put(key, value) /** Add Hadoop configuration specific to a single partition and attempt. */ @@ -371,7 +370,7 @@ private[spark] object HadoopRDD extends Logging { override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Partition, context: TaskContext) = { + override def compute(split: Partition, context: TaskContext): Iterator[U] = { val partition = split.asInstanceOf[HadoopPartition] val inputSplit = partition.inputSplit.value f(inputSplit, firstParent[T].iterator(split, context)) diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index e2267861e79df..0c28f045e46e9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.sql.{Connection, ResultSet} +import java.sql.{PreparedStatement, Connection, ResultSet} import scala.reflect.ClassTag @@ -28,8 +28,9 @@ import org.apache.spark.util.NextIterator import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { - override def index = idx + override def index: Int = idx } + // TODO: Expose a jdbcRDD function in SparkContext and mark this as semi-private /** * An RDD that executes an SQL query on a JDBC connection and reads results. @@ -70,7 +71,8 @@ class JdbcRDD[T: ClassTag]( }).toArray } - override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { + override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T] + { context.addTaskCompletionListener{ context => closeIfNeeded() } val part = thePart.asInstanceOf[JdbcPartition] val conn = getConnection() @@ -88,7 +90,7 @@ class JdbcRDD[T: ClassTag]( stmt.setLong(2, part.upper) val rs = stmt.executeQuery() - override def getNext: T = { + override def getNext(): T = { if (rs.next()) { mapRow(rs) } else { diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index 4883fb828814c..a838aac6e8d1a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -31,6 +31,6 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Partition, context: TaskContext) = + override def compute(split: Partition, context: TaskContext): Iterator[U] = f(context, split.index, firstParent[T].iterator(split, context)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 7fb94840df99c..2ab967f4bb313 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -238,7 +238,7 @@ private[spark] object NewHadoopRDD { override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Partition, context: TaskContext) = { + override def compute(split: Partition, context: TaskContext): Iterator[U] = { val partition = split.asInstanceOf[NewHadoopPartition] val inputSplit = partition.serializableHadoopSplit.value f(inputSplit, firstParent[T].iterator(split, context)) diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index f12d0cffaba34..e2394e28f8d26 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -98,7 +98,7 @@ private[spark] class ParallelCollectionRDD[T: ClassTag]( slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray } - override def compute(s: Partition, context: TaskContext) = { + override def compute(s: Partition, context: TaskContext): Iterator[T] = { new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index f781a8d776f2a..a00f4c1cdff91 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -40,7 +40,7 @@ private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterF .filter(s => partitionFilterFunc(s.index)).zipWithIndex .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } - override def getParents(partitionId: Int) = { + override def getParents(partitionId: Int): List[Int] = { List(partitions(partitionId).asInstanceOf[PartitionPruningRDDPartition].parentSplit.index) } } @@ -59,8 +59,10 @@ class PartitionPruningRDD[T: ClassTag]( @transient partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { - override def compute(split: Partition, context: TaskContext) = firstParent[T].iterator( - split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context) + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + firstParent[T].iterator( + split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context) + } override protected def getPartitions: Array[Partition] = getDependencies.head.asInstanceOf[PruneDependency[T]].partitions @@ -74,7 +76,7 @@ object PartitionPruningRDD { * Create a PartitionPruningRDD. This function can be used to create the PartitionPruningRDD * when its type T is not known at compile time. */ - def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) = { + def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean): PartitionPruningRDD[T] = { new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index ed79032893d33..dc60d48927624 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -149,10 +149,10 @@ private[spark] class PipedRDD[T: ClassTag]( }.start() // Return an iterator that read lines from the process's stdout - val lines = Source.fromInputStream(proc.getInputStream).getLines + val lines = Source.fromInputStream(proc.getInputStream).getLines() new Iterator[String] { - def next() = lines.next() - def hasNext = { + def next(): String = lines.next() + def hasNext: Boolean = { if (lines.hasNext) { true } else { @@ -162,7 +162,7 @@ private[spark] class PipedRDD[T: ClassTag]( } // cleanup task working directory if used - if (workInTaskDirectory == true) { + if (workInTaskDirectory) { scala.util.control.Exception.ignoring(classOf[IOException]) { Utils.deleteRecursively(new File(taskDirectory)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a4c74ed03e330..ddbfd5624e741 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -186,7 +186,7 @@ abstract class RDD[T: ClassTag]( } /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ - def getStorageLevel = storageLevel + def getStorageLevel: StorageLevel = storageLevel // Our dependencies and partitions will be gotten by calling subclass's methods below, and will // be overwritten when we're checkpointed @@ -746,13 +746,13 @@ abstract class RDD[T: ClassTag]( def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = { zipPartitions(other, preservesPartitioning = false) { (thisIter, otherIter) => new Iterator[(T, U)] { - def hasNext = (thisIter.hasNext, otherIter.hasNext) match { + def hasNext: Boolean = (thisIter.hasNext, otherIter.hasNext) match { case (true, true) => true case (false, false) => false case _ => throw new SparkException("Can only zip RDDs with " + "same number of elements in each partition") } - def next = (thisIter.next, otherIter.next) + def next(): (T, U) = (thisIter.next(), otherIter.next()) } } } @@ -868,8 +868,8 @@ abstract class RDD[T: ClassTag]( // Our partitioner knows how to handle T (which, since we have a partitioner, is // really (K, V)) so make a new Partitioner that will de-tuple our fake tuples val p2 = new Partitioner() { - override def numPartitions = p.numPartitions - override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1) + override def numPartitions: Int = p.numPartitions + override def getPartition(k: Any): Int = p.getPartition(k.asInstanceOf[(Any, _)]._1) } // Unfortunately, since we're making a new p2, we'll get ShuffleDependencies // anyway, and when calling .keys, will not have a partitioner set, even though @@ -1394,7 +1394,7 @@ abstract class RDD[T: ClassTag]( } /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ - def context = sc + def context: SparkContext = sc /** * Private API for changing an RDD's ClassTag. diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index d9fe6847254fa..2dc47f95937cb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -17,14 +17,12 @@ package org.apache.spark.rdd -import scala.reflect.ClassTag - import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.Serializer private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { - override val index = idx + override val index: Int = idx override def hashCode(): Int = idx } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index ed24ea22a661c..c27f435eb9c5a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -105,7 +105,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( seq } } - def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match { + def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit): Unit = dep match { case NarrowCoGroupSplitDep(rdd, _, itsSplit) => rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index aece683ff3199..4239e7e22af89 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -44,7 +44,7 @@ private[spark] class UnionPartition[T: ClassTag]( var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex) - def preferredLocations() = rdd.preferredLocations(parentPartition) + def preferredLocations(): Seq[String] = rdd.preferredLocations(parentPartition) override val index: Int = idx diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 95b2dd954e9f4..d0be304762e1f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -32,7 +32,7 @@ private[spark] class ZippedPartitionsPartition( override val index: Int = idx var partitionValues = rdds.map(rdd => rdd.partitions(idx)) - def partitions = partitionValues + def partitions: Seq[Partition] = partitionValues @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index fa83372bb4d11..e0edd7d4ae968 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -39,8 +39,11 @@ class AccumulableInfo ( } object AccumulableInfo { - def apply(id: Long, name: String, update: Option[String], value: String) = + def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = { new AccumulableInfo(id, name, update, value) + } - def apply(id: Long, name: String, value: String) = new AccumulableInfo(id, name, None, value) + def apply(id: Long, name: String, value: String): AccumulableInfo = { + new AccumulableInfo(id, name, None, value) + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8feac6cb6b7a1..b405bd3338e7c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -946,7 +946,7 @@ class DAGScheduler( val stage = stageIdToStage(task.stageId) - def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None) = { + def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { val serviceTime = stage.latestInfo.submissionTime match { case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) case _ => "Unknown" diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 34fa6d27c3a45..c0d889360ae99 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -149,47 +149,60 @@ private[spark] class EventLoggingListener( } // Events that do not trigger a flush - override def onStageSubmitted(event: SparkListenerStageSubmitted) = - logEvent(event) - override def onTaskStart(event: SparkListenerTaskStart) = - logEvent(event) - override def onTaskGettingResult(event: SparkListenerTaskGettingResult) = - logEvent(event) - override def onTaskEnd(event: SparkListenerTaskEnd) = - logEvent(event) - override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate) = - logEvent(event) + override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = logEvent(event) + + override def onTaskStart(event: SparkListenerTaskStart): Unit = logEvent(event) + + override def onTaskGettingResult(event: SparkListenerTaskGettingResult): Unit = logEvent(event) + + override def onTaskEnd(event: SparkListenerTaskEnd): Unit = logEvent(event) + + override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = logEvent(event) // Events that trigger a flush - override def onStageCompleted(event: SparkListenerStageCompleted) = - logEvent(event, flushLogger = true) - override def onJobStart(event: SparkListenerJobStart) = - logEvent(event, flushLogger = true) - override def onJobEnd(event: SparkListenerJobEnd) = + override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { logEvent(event, flushLogger = true) - override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded) = + } + + override def onJobStart(event: SparkListenerJobStart): Unit = logEvent(event, flushLogger = true) + + override def onJobEnd(event: SparkListenerJobEnd): Unit = logEvent(event, flushLogger = true) + + override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { logEvent(event, flushLogger = true) - override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved) = + } + + override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved): Unit = { logEvent(event, flushLogger = true) - override def onUnpersistRDD(event: SparkListenerUnpersistRDD) = + } + + override def onUnpersistRDD(event: SparkListenerUnpersistRDD): Unit = { logEvent(event, flushLogger = true) - override def onApplicationStart(event: SparkListenerApplicationStart) = + } + + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { logEvent(event, flushLogger = true) - override def onApplicationEnd(event: SparkListenerApplicationEnd) = + } + + override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { logEvent(event, flushLogger = true) - override def onExecutorAdded(event: SparkListenerExecutorAdded) = + } + override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit = { logEvent(event, flushLogger = true) - override def onExecutorRemoved(event: SparkListenerExecutorRemoved) = + } + + override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = { logEvent(event, flushLogger = true) + } // No-op because logging every update would be overkill - override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate) { } + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } /** * Stop logging events. The event log file will be renamed so that it loses the * ".inprogress" suffix. */ - def stop() = { + def stop(): Unit = { writer.foreach(_.close()) val target = new Path(logPath) diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 8aa528ac573d0..e55b76c36cc5f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -57,7 +57,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener private val stageIdToJobId = new HashMap[Int, Int] private val jobIdToStageIds = new HashMap[Int, Seq[Int]] private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue() = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") } createLogDir() diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 29879b374b801..382b09422a4a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -34,7 +34,7 @@ private[spark] class JobWaiter[T]( @volatile private var _jobFinished = totalTasks == 0 - def jobFinished = _jobFinished + def jobFinished: Boolean = _jobFinished // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero // partition RDDs), we set the jobResult directly to JobSucceeded. diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 759df023a6dcf..a3caa9f000c89 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -160,7 +160,7 @@ private[spark] object OutputCommitCoordinator { class OutputCommitCoordinatorActor(outputCommitCoordinator: OutputCommitCoordinator) extends Actor with ActorLogReceive with Logging { - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case AskPermissionToCommitOutput(stage, partition, taskAttempt) => sender ! outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt) case StopCoordinator => diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 4a9ff918afe25..e074ce6ebff0b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -64,5 +64,5 @@ private[spark] class ResultTask[T, U]( // This is only callable on the driver side. override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ResultTask(" + stageId + ", " + partitionId + ")" + override def toString: String = "ResultTask(" + stageId + ", " + partitionId + ")" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 79709089c0da4..fd0d484b45460 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -47,7 +47,7 @@ private[spark] class ShuffleMapTask( /** A constructor used only in test suites. This does not require passing in an RDD. */ def this(partitionId: Int) { - this(0, null, new Partition { override def index = 0 }, null) + this(0, null, new Partition { override def index: Int = 0 }, null) } @transient private val preferredLocs: Seq[TaskLocation] = { @@ -83,5 +83,5 @@ private[spark] class ShuffleMapTask( override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) + override def toString: String = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 52720d48ca67f..b711ff209af94 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -300,7 +300,7 @@ private[spark] object StatsReportListener extends Logging { } def showDistribution(heading: String, dOpt: Option[Distribution], format:String) { - def f(d: Double) = format.format(d) + def f(d: Double): String = format.format(d) showDistribution(heading, dOpt, f _) } @@ -346,7 +346,7 @@ private[spark] object StatsReportListener extends Logging { /** * Reformat a time interval in milliseconds to a prettier format for output */ - def millisToString(ms: Long) = { + def millisToString(ms: Long): String = { val (size, units) = if (ms > hours) { (ms.toDouble / hours, "hours") diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index cc13f57a49b89..4cbc6e84a6bdd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -133,7 +133,7 @@ private[spark] class Stage( def attemptId: Int = nextAttemptId - override def toString = "Stage " + id + override def toString: String = "Stage " + id override def hashCode(): Int = id diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 6fa1f2c880f7a..132a9ced77700 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -81,9 +81,11 @@ class TaskInfo( def status: String = { if (running) { - "RUNNING" - } else if (gettingResult) { - "GET RESULT" + if (gettingResult) { + "GET RESULT" + } else { + "RUNNING" + } } else if (failed) { "FAILED" } else if (successful) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala index 10c685f29d3ac..da07ce2c6ea49 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala @@ -29,23 +29,22 @@ private[spark] sealed trait TaskLocation { /** * A location that includes both a host and an executor id on that host. */ -private [spark] case class ExecutorCacheTaskLocation(override val host: String, - val executorId: String) extends TaskLocation { -} +private [spark] +case class ExecutorCacheTaskLocation(override val host: String, executorId: String) + extends TaskLocation /** * A location on a host. */ private [spark] case class HostTaskLocation(override val host: String) extends TaskLocation { - override def toString = host + override def toString: String = host } /** * A location on a host that is cached by HDFS. */ -private [spark] case class HDFSCacheTaskLocation(override val host: String) - extends TaskLocation { - override def toString = TaskLocation.inMemoryLocationTag + host +private [spark] case class HDFSCacheTaskLocation(override val host: String) extends TaskLocation { + override def toString: String = TaskLocation.inMemoryLocationTag + host } private[spark] object TaskLocation { @@ -54,14 +53,16 @@ private[spark] object TaskLocation { // confusion. See RFC 952 and RFC 1123 for information about the format of hostnames. val inMemoryLocationTag = "hdfs_cache_" - def apply(host: String, executorId: String) = new ExecutorCacheTaskLocation(host, executorId) + def apply(host: String, executorId: String): TaskLocation = { + new ExecutorCacheTaskLocation(host, executorId) + } /** * Create a TaskLocation from a string returned by getPreferredLocations. * These strings have the form [hostname] or hdfs_cache_[hostname], depending on whether the * location is cached. */ - def apply(str: String) = { + def apply(str: String): TaskLocation = { val hstr = str.stripPrefix(inMemoryLocationTag) if (hstr.equals(str)) { new HostTaskLocation(str) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index f33fd4450b2a6..076b36e86c0ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -373,17 +373,17 @@ private[spark] class TaskSchedulerImpl( } def handleSuccessfulTask( - taskSetManager: TaskSetManager, - tid: Long, - taskResult: DirectTaskResult[_]) = synchronized { + taskSetManager: TaskSetManager, + tid: Long, + taskResult: DirectTaskResult[_]): Unit = synchronized { taskSetManager.handleSuccessfulTask(tid, taskResult) } def handleFailedTask( - taskSetManager: TaskSetManager, - tid: Long, - taskState: TaskState, - reason: TaskEndReason) = synchronized { + taskSetManager: TaskSetManager, + tid: Long, + taskState: TaskState, + reason: TaskEndReason): Unit = synchronized { taskSetManager.handleFailedTask(tid, taskState, reason) if (!taskSetManager.isZombie && taskState != TaskState.KILLED) { // Need to revive offers again now that the task set manager state has been updated to @@ -423,7 +423,7 @@ private[spark] class TaskSchedulerImpl( starvationTimer.cancel() } - override def defaultParallelism() = backend.defaultParallelism() + override def defaultParallelism(): Int = backend.defaultParallelism() // Check for speculatable tasks in all our active jobs. def checkSpeculatableTasks() { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 529237f0d35dc..d509881c74fef 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.nio.ByteBuffer import java.util.Arrays +import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -29,6 +30,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.TaskState.TaskState import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -97,7 +99,8 @@ private[spark] class TaskSetManager( var calculatedTasks = 0 val runningTasksSet = new HashSet[Long] - override def runningTasks = runningTasksSet.size + + override def runningTasks: Int = runningTasksSet.size // True once no more tasks should be launched for this task set manager. TaskSetManagers enter // the zombie state once at least one attempt of each task has completed successfully, or if the @@ -168,9 +171,9 @@ private[spark] class TaskSetManager( var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level - override def schedulableQueue = null + override def schedulableQueue: ConcurrentLinkedQueue[Schedulable] = null - override def schedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.NONE var emittedTaskSizeWarning = false @@ -585,7 +588,7 @@ private[spark] class TaskSetManager( /** * Marks the task as getting result and notifies the DAG Scheduler */ - def handleTaskGettingResult(tid: Long) = { + def handleTaskGettingResult(tid: Long): Unit = { val info = taskInfos(tid) info.markGettingResult() sched.dagScheduler.taskGettingResult(info) @@ -612,7 +615,7 @@ private[spark] class TaskSetManager( /** * Marks the task as successful and notifies the DAGScheduler that a task has ended. */ - def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = { + def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = { val info = taskInfos(tid) val index = info.index info.markSuccessful() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 87ebf31139ce9..5d258d9da4d1a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -85,7 +85,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) } - def receiveWithLogging = { + def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisterExecutor(executorId, hostPort, cores, logUrls) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorDataMap.contains(executorId)) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index f14aaeea0a25c..5a38ad9f2b12c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -109,7 +109,7 @@ private[spark] abstract class YarnSchedulerBackend( context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } - override def receive = { + override def receive: PartialFunction[Any, Unit] = { case RegisterClusterManager => logInfo(s"ApplicationMaster registered as $sender") amActor = Some(sender) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala index aa3ec0f8cfb9c..8df4f3b554c41 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala @@ -24,7 +24,7 @@ private[spark] object MemoryUtils { val OVERHEAD_FRACTION = 0.10 val OVERHEAD_MINIMUM = 384 - def calculateTotalMemory(sc: SparkContext) = { + def calculateTotalMemory(sc: SparkContext): Int = { sc.conf.getInt("spark.mesos.executor.memoryOverhead", math.max(OVERHEAD_FRACTION * sc.executorMemory, OVERHEAD_MINIMUM).toInt) + sc.executorMemory } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 06bb527522141..b381436839227 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -387,7 +387,7 @@ private[spark] class MesosSchedulerBackend( } // TODO: query Mesos for number of cores - override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8) + override def defaultParallelism(): Int = sc.conf.getInt("spark.default.parallelism", 8) override def applicationId(): String = Option(appId).getOrElse { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index d95426d918e19..eb3f999b5b375 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -59,7 +59,7 @@ private[spark] class LocalActor( private val executor = new Executor( localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true) - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case ReviveOffers => reviveOffers() @@ -117,7 +117,7 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: localActor ! ReviveOffers } - override def defaultParallelism() = + override def defaultParallelism(): Int = scheduler.conf.getInt("spark.default.parallelism", totalCores) override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 1baa0e009f3ae..dfbde7c8a1b0d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -59,9 +59,10 @@ private[spark] class JavaSerializationStream( } private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader) -extends DeserializationStream { + extends DeserializationStream { + private val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass) = + override def resolveClass(desc: ObjectStreamClass): Class[_] = Class.forName(desc.getName, false, loader) } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index dc7aa99738c17..579fb6624e692 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -49,10 +49,20 @@ class KryoSerializer(conf: SparkConf) with Logging with Serializable { - private val bufferSize = - (conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) * 1024 * 1024).toInt + private val bufferSizeMb = conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) + if (bufferSizeMb >= 2048) { + throw new IllegalArgumentException("spark.kryoserializer.buffer.mb must be less than " + + s"2048 mb, got: + $bufferSizeMb mb.") + } + private val bufferSize = (bufferSizeMb * 1024 * 1024).toInt + + val maxBufferSizeMb = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) + if (maxBufferSizeMb >= 2048) { + throw new IllegalArgumentException("spark.kryoserializer.buffer.max.mb must be less than " + + s"2048 mb, got: + $maxBufferSizeMb mb.") + } + private val maxBufferSize = maxBufferSizeMb * 1024 * 1024 - private val maxBufferSize = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) * 1024 * 1024 private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) private val userRegistrator = conf.getOption("spark.kryo.registrator") @@ -60,7 +70,7 @@ class KryoSerializer(conf: SparkConf) .split(',') .filter(!_.isEmpty) - def newKryoOutput() = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) + def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) def newKryo(): Kryo = { val instantiator = new EmptyScalaKryoInstantiator diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 7de2f9cbb2866..d0178dfde6935 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -106,12 +106,13 @@ class FileShuffleBlockManager(conf: SparkConf) * when the writers are closed successfully */ def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, - writeMetrics: ShuffleWriteMetrics) = { + writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { new ShuffleWriterGroup { shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) private val shuffleState = shuffleStates(shuffleId) private var fileGroup: ShuffleFileGroup = null + val openStartTime = System.nanoTime val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) { fileGroup = getUnusedFileGroup() Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => @@ -135,6 +136,9 @@ class FileShuffleBlockManager(conf: SparkConf) blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics) } } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, so should be included in the shuffle write time. + writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) override def releaseWriters(success: Boolean) { if (consolidateShuffleFiles) { @@ -268,7 +272,7 @@ object FileShuffleBlockManager { new PrimitiveVector[Long]() } - def apply(bucketId: Int) = files(bucketId) + def apply(bucketId: Int): File = files(bucketId) def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { assert(offsets.length == lengths.length) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index b292587d37028..87fd161e06c85 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -80,7 +80,7 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { * end of the output file. This will be used by getBlockLocation to figure out where each block * begins and ends. * */ - def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]) = { + def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = { val indexFile = getIndexFile(shuffleId, mapId) val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) try { @@ -121,5 +121,5 @@ class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { } } - override def stop() = {} + override def stop(): Unit = {} } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index fa2e617762f55..55ea0f17b156a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -63,6 +63,9 @@ private[spark] class SortShuffleWriter[K, V, C]( sorter.insertAll(records) } + // Don't bother including the time to open the merged output file in the shuffle write time, + // because it just opens a single file, so is typically too fast to measure accurately + // (see SPARK-3570). val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId) val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 1f012941c85ab..c186fd360fef6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -35,13 +35,13 @@ sealed abstract class BlockId { def name: String // convenience methods - def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None - def isRDD = isInstanceOf[RDDBlockId] - def isShuffle = isInstanceOf[ShuffleBlockId] - def isBroadcast = isInstanceOf[BroadcastBlockId] + def asRDDId: Option[RDDBlockId] = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None + def isRDD: Boolean = isInstanceOf[RDDBlockId] + def isShuffle: Boolean = isInstanceOf[ShuffleBlockId] + def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId] - override def toString = name - override def hashCode = name.hashCode + override def toString: String = name + override def hashCode: Int = name.hashCode override def equals(other: Any): Boolean = other match { case o: BlockId => getClass == o.getClass && name.equals(o.name) case _ => false @@ -50,54 +50,54 @@ sealed abstract class BlockId { @DeveloperApi case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { - def name = "rdd_" + rddId + "_" + splitIndex + override def name: String = "rdd_" + rddId + "_" + splitIndex } // Format of the shuffle block ids (including data and index) should be kept in sync with // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getBlockData(). @DeveloperApi case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } @DeveloperApi case class ShuffleDataBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" } @DeveloperApi case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" } @DeveloperApi case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { - def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) + override def name: String = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) } @DeveloperApi case class TaskResultBlockId(taskId: Long) extends BlockId { - def name = "taskresult_" + taskId + override def name: String = "taskresult_" + taskId } @DeveloperApi case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId { - def name = "input-" + streamId + "-" + uniqueId + override def name: String = "input-" + streamId + "-" + uniqueId } /** Id associated with temporary local data managed as blocks. Not serializable. */ private[spark] case class TempLocalBlockId(id: UUID) extends BlockId { - def name = "temp_local_" + id + override def name: String = "temp_local_" + id } /** Id associated with temporary shuffle data managed as blocks. Not serializable. */ private[spark] case class TempShuffleBlockId(id: UUID) extends BlockId { - def name = "temp_shuffle_" + id + override def name: String = "temp_shuffle_" + id } // Intended only for testing purposes private[spark] case class TestBlockId(id: String) extends BlockId { - def name = "test_" + id + override def name: String = "test_" + id } @DeveloperApi @@ -112,7 +112,7 @@ object BlockId { val TEST = "test_(.*)".r /** Converts a BlockId "name" String back into a BlockId. */ - def apply(id: String) = id match { + def apply(id: String): BlockId = id match { case RDD(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 80d66e59132da..1dff09a75d038 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -535,9 +535,14 @@ private[spark] class BlockManager( /* We'll store the bytes in memory if the block's storage level includes * "memory serialized", or if it should be cached as objects in memory * but we only requested its serialized bytes. */ - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - memoryStore.putBytes(blockId, copyForMemory, level) + memoryStore.putBytes(blockId, bytes.limit, () => { + // https://issues.apache.org/jira/browse/SPARK-6076 + // If the file size is bigger than the free memory, OOM will happen. So if we cannot + // put it into MemoryStore, copyForMemory should not be created. That's why this + // action is put into a `() => ByteBuffer` and created lazily. + val copyForMemory = ByteBuffer.allocate(bytes.limit) + copyForMemory.put(bytes) + }) bytes.rewind() } if (!asBlockResult) { @@ -991,15 +996,23 @@ private[spark] class BlockManager( putIterator(blockId, Iterator(value), level, tellMaster) } + def dropFromMemory( + blockId: BlockId, + data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { + dropFromMemory(blockId, () => data) + } + /** * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory * store reaches its limit and needs to free up space. * + * If `data` is not put on disk, it won't be created. + * * Return the block status if the given block has been updated, else None. */ def dropFromMemory( blockId: BlockId, - data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { + data: () => Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { logInfo(s"Dropping block $blockId from memory") val info = blockInfo.get(blockId).orNull @@ -1023,7 +1036,7 @@ private[spark] class BlockManager( // Drop to disk, if storage level requires if (level.useDisk && !diskStore.contains(blockId)) { logInfo(s"Writing block $blockId to disk") - data match { + data() match { case Left(elements) => diskStore.putArray(blockId, elements, level, returnValues = false) case Right(bytes) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index b177a59c721df..a6f1ebf325a7c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -77,11 +77,11 @@ class BlockManagerId private ( @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString = s"BlockManagerId($executorId, $host, $port)" + override def toString: String = s"BlockManagerId($executorId, $host, $port)" override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port - override def equals(that: Any) = that match { + override def equals(that: Any): Boolean = that match { case id: BlockManagerId => executorId == id.executorId && port == id.port && host == id.host case _ => @@ -100,10 +100,10 @@ private[spark] object BlockManagerId { * @param port Port of the block manager. * @return A new [[org.apache.spark.storage.BlockManagerId]]. */ - def apply(execId: String, host: String, port: Int) = + def apply(execId: String, host: String, port: Int): BlockManagerId = getCachedBlockManagerId(new BlockManagerId(execId, host, port)) - def apply(in: ObjectInput) = { + def apply(in: ObjectInput): BlockManagerId = { val obj = new BlockManagerId() obj.readExternal(in) getCachedBlockManagerId(obj) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 654796f23c96e..061964826f08b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -79,7 +79,7 @@ class BlockManagerMaster( * Check if block manager master has a block. Note that this can be used to check for only * those blocks that are reported to block manager master. */ - def contains(blockId: BlockId) = { + def contains(blockId: BlockId): Boolean = { !getLocations(blockId).isEmpty } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 787b0f96bec32..5b5328016124e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -52,7 +52,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private val akkaTimeout = AkkaUtils.askTimeout(conf) - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => register(blockManagerId, maxMemSize, slaveActor) sender ! true @@ -421,7 +421,7 @@ private[spark] class BlockManagerInfo( // Mapping from block id to its status. private val _blocks = new JHashMap[BlockId, BlockStatus] - def getStatus(blockId: BlockId) = Option(_blocks.get(blockId)) + def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId)) def updateLastSeenMs() { _lastSeenMs = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 8462871e798a5..52fb896c4e21f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -38,7 +38,7 @@ class BlockManagerSlaveActor( import context.dispatcher // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveWithLogging = { + override def receiveWithLogging: PartialFunction[Any, Unit] = { case RemoveBlock(blockId) => doAsync[Boolean]("removing block " + blockId, sender) { blockManager.removeBlock(blockId) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 81164178b9e8e..f703e50b6b0ac 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -82,11 +82,13 @@ private[spark] class DiskBlockObjectWriter( { /** Intercepts write calls and tracks total time spent writing. Not thread safe. */ private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream { - def write(i: Int): Unit = callWithTiming(out.write(i)) - override def write(b: Array[Byte]) = callWithTiming(out.write(b)) - override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len)) - override def close() = out.close() - override def flush() = out.flush() + override def write(i: Int): Unit = callWithTiming(out.write(i)) + override def write(b: Array[Byte]): Unit = callWithTiming(out.write(b)) + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + callWithTiming(out.write(b, off, len)) + } + override def close(): Unit = out.close() + override def flush(): Unit = out.flush() } /** The file channel, used for repositioning / truncating the file. */ @@ -141,8 +143,9 @@ private[spark] class DiskBlockObjectWriter( if (syncWrites) { // Force outstanding writes to disk and track how long it takes objOut.flush() - def sync = fos.getFD.sync() - callWithTiming(sync) + callWithTiming { + fos.getFD.sync() + } } objOut.close() diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 12cd8ea3bdf1f..2883137872600 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -47,6 +47,8 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon logError("Failed to create any local dir.") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) } + // The content of subDirs is immutable but the content of subDirs(i) is mutable. And the content + // of subDirs(i) is protected by the lock of subDirs(i) private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) private val shutdownHook = addShutdownHook() @@ -61,20 +63,17 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon val subDirId = (hash / localDirs.length) % subDirsPerLocalDir // Create the subdirectory if it doesn't already exist - var subDir = subDirs(dirId)(subDirId) - if (subDir == null) { - subDir = subDirs(dirId).synchronized { - val old = subDirs(dirId)(subDirId) - if (old != null) { - old - } else { - val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) - if (!newDir.exists() && !newDir.mkdir()) { - throw new IOException(s"Failed to create local dir in $newDir.") - } - subDirs(dirId)(subDirId) = newDir - newDir + val subDir = subDirs(dirId).synchronized { + val old = subDirs(dirId)(subDirId) + if (old != null) { + old + } else { + val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) + if (!newDir.exists() && !newDir.mkdir()) { + throw new IOException(s"Failed to create local dir in $newDir.") } + subDirs(dirId)(subDirId) = newDir + newDir } } @@ -91,7 +90,12 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon /** List all the files currently stored on disk by the disk manager. */ def getAllFiles(): Seq[File] = { // Get all the files inside the array of array of directories - subDirs.flatten.filter(_ != null).flatMap { dir => + subDirs.flatMap { dir => + dir.synchronized { + // Copy the content of dir because it may be modified in other threads + dir.clone() + } + }.filter(_ != null).flatMap { dir => val files = dir.listFiles() if (files != null) files else Seq.empty } diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala index 132502b75f8cd..95e2d688d9b17 100644 --- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -24,5 +24,7 @@ import java.io.File * based off an offset and a length. */ private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) { - override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) + override def toString: String = { + "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) + } } diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 1be860aea63d0..ed609772e6979 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -98,6 +98,26 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } + /** + * Use `size` to test if there is enough space in MemoryStore. If so, create the ByteBuffer and + * put it into MemoryStore. Otherwise, the ByteBuffer won't be created. + * + * The caller should guarantee that `size` is correct. + */ + def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): PutResult = { + // Work on a duplicate - since the original input might be used elsewhere. + lazy val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer] + val putAttempt = tryToPut(blockId, () => bytes, size, deserialized = false) + val data = + if (putAttempt.success) { + assert(bytes.limit == size) + Right(bytes.duplicate()) + } else { + null + } + PutResult(size, data, putAttempt.droppedBlocks) + } + override def putArray( blockId: BlockId, values: Array[Any], @@ -312,11 +332,22 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) blockId.asRDDId.map(_.rddId) } + private def tryToPut( + blockId: BlockId, + value: Any, + size: Long, + deserialized: Boolean): ResultWithDroppedBlocks = { + tryToPut(blockId, () => value, size, deserialized) + } + /** * Try to put in a set of values, if we can free up enough space. The value should either be * an Array if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) size * must also be passed by the caller. * + * `value` will be lazily created. If it cannot be put into MemoryStore or disk, `value` won't be + * created to avoid OOM since it may be a big ByteBuffer. + * * Synchronize on `accountingLock` to ensure that all the put requests and its associated block * dropping is done by only on thread at a time. Otherwise while one thread is dropping * blocks to free memory for one block, another thread may use up the freed space for @@ -326,7 +357,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) */ private def tryToPut( blockId: BlockId, - value: Any, + value: () => Any, size: Long, deserialized: Boolean): ResultWithDroppedBlocks = { @@ -345,7 +376,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) droppedBlocks ++= freeSpaceResult.droppedBlocks if (enoughFreeSpace) { - val entry = new MemoryEntry(value, size, deserialized) + val entry = new MemoryEntry(value(), size, deserialized) entries.synchronized { entries.put(blockId, entry) currentMemory += size @@ -357,12 +388,12 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } else { // Tell the block manager that we couldn't put it in memory so that it can drop it to // disk if the block allows disk storage. - val data = if (deserialized) { - Left(value.asInstanceOf[Array[Any]]) + lazy val data = if (deserialized) { + Left(value().asInstanceOf[Array[Any]]) } else { - Right(value.asInstanceOf[ByteBuffer].duplicate()) + Right(value().asInstanceOf[ByteBuffer].duplicate()) } - val droppedBlockStatus = blockManager.dropFromMemory(blockId, data) + val droppedBlockStatus = blockManager.dropFromMemory(blockId, () => data) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } // Release the unroll memory used because we no longer need the underlying Array diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 120c327a7e580..0186eb30a1905 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -36,7 +36,7 @@ class RDDInfo( def isCached: Boolean = (memSize + diskSize + tachyonSize > 0) && numCachedPartitions > 0 - override def toString = { + override def toString: String = { import Utils.bytesToString ("RDD \"%s\" (%d) StorageLevel: %s; CachedPartitions: %d; TotalPartitions: %d; " + "MemorySize: %s; TachyonSize: %s; DiskSize: %s").format( @@ -44,7 +44,7 @@ class RDDInfo( bytesToString(memSize), bytesToString(tachyonSize), bytesToString(diskSize)) } - override def compare(that: RDDInfo) = { + override def compare(that: RDDInfo): Int = { this.id - that.id } } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index e5e1cf5a69a19..134abea866218 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -50,11 +50,11 @@ class StorageLevel private( def this() = this(false, true, false, false) // For deserialization - def useDisk = _useDisk - def useMemory = _useMemory - def useOffHeap = _useOffHeap - def deserialized = _deserialized - def replication = _replication + def useDisk: Boolean = _useDisk + def useMemory: Boolean = _useMemory + def useOffHeap: Boolean = _useOffHeap + def deserialized: Boolean = _deserialized + def replication: Int = _replication assert(replication < 40, "Replication restricted to be less than 40 for calculating hash codes") @@ -80,7 +80,7 @@ class StorageLevel private( false } - def isValid = (useMemory || useDisk || useOffHeap) && (replication > 0) + def isValid: Boolean = (useMemory || useDisk || useOffHeap) && (replication > 0) def toInt: Int = { var ret = 0 @@ -183,7 +183,7 @@ object StorageLevel { useMemory: Boolean, useOffHeap: Boolean, deserialized: Boolean, - replication: Int) = { + replication: Int): StorageLevel = { getCachedStorageLevel( new StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication)) } @@ -197,7 +197,7 @@ object StorageLevel { useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, - replication: Int = 1) = { + replication: Int = 1): StorageLevel = { getCachedStorageLevel(new StorageLevel(useDisk, useMemory, false, deserialized, replication)) } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index def49e80a3605..7d75929b96f75 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -19,7 +19,6 @@ package org.apache.spark.storage import scala.collection.mutable -import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ @@ -32,7 +31,7 @@ class StorageStatusListener extends SparkListener { // This maintains only blocks that are cached (i.e. storage level is not StorageLevel.NONE) private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]() - def storageStatusList = executorIdToStorageStatus.values.toSeq + def storageStatusList: Seq[StorageStatus] = executorIdToStorageStatus.values.toSeq /** Update storage status list to reflect updated block statuses */ private def updateStorageStatus(execId: String, updatedBlocks: Seq[(BlockId, BlockStatus)]) { @@ -56,7 +55,7 @@ class StorageStatusListener extends SparkListener { } } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val info = taskEnd.taskInfo val metrics = taskEnd.taskMetrics if (info != null && metrics != null) { @@ -67,7 +66,7 @@ class StorageStatusListener extends SparkListener { } } - override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) = synchronized { + override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized { updateStorageStatus(unpersistRDD.rddId) } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index 2ab6a8f3ec1d4..af873034215a9 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -20,8 +20,8 @@ package org.apache.spark.storage import java.text.SimpleDateFormat import java.util.{Date, Random} -import tachyon.TachyonURI -import tachyon.client.{TachyonFile, TachyonFS} +import tachyon.client.TachyonFS +import tachyon.client.TachyonFile import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode @@ -40,7 +40,7 @@ private[spark] class TachyonBlockManager( val master: String) extends Logging { - val client = if (master != null && master != "") TachyonFS.get(new TachyonURI(master)) else null + val client = if (master != null && master != "") TachyonFS.get(master) else null if (client == null) { logError("Failed to connect to the Tachyon as the master address is not configured") @@ -60,11 +60,11 @@ private[spark] class TachyonBlockManager( addShutdownHook() def removeFile(file: TachyonFile): Boolean = { - client.delete(new TachyonURI(file.getPath()), false) + client.delete(file.getPath(), false) } def fileExists(file: TachyonFile): Boolean = { - client.exist(new TachyonURI(file.getPath())) + client.exist(file.getPath()) } def getFile(filename: String): TachyonFile = { @@ -81,7 +81,7 @@ private[spark] class TachyonBlockManager( if (old != null) { old } else { - val path = new TachyonURI(s"${tachyonDirs(dirId)}/${"%02x".format(subDirId)}") + val path = tachyonDirs(dirId) + "/" + "%02x".format(subDirId) client.mkdir(path) val newDir = client.getFile(path) subDirs(dirId)(subDirId) = newDir @@ -89,7 +89,7 @@ private[spark] class TachyonBlockManager( } } } - val filePath = new TachyonURI(s"$subDir/$filename") + val filePath = subDir + "/" + filename if(!client.exist(filePath)) { client.createFile(filePath) } @@ -101,7 +101,7 @@ private[spark] class TachyonBlockManager( // TODO: Some of the logic here could be consolidated/de-duplicated with that in the DiskStore. private def createTachyonDirs(): Array[TachyonFile] = { - logDebug(s"Creating tachyon directories at root dirs '$rootDirs'") + logDebug("Creating tachyon directories at root dirs '" + rootDirs + "'") val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") rootDirs.split(",").map { rootDir => var foundLocalDir = false @@ -113,21 +113,22 @@ private[spark] class TachyonBlockManager( tries += 1 try { tachyonDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - val path = new TachyonURI(s"$rootDir/spark-tachyon-$tachyonDirId") + val path = rootDir + "/" + "spark-tachyon-" + tachyonDirId if (!client.exist(path)) { foundLocalDir = client.mkdir(path) tachyonDir = client.getFile(path) } } catch { case e: Exception => - logWarning(s"Attempt $tries to create tachyon dir $tachyonDir failed", e) + logWarning("Attempt " + tries + " to create tachyon dir " + tachyonDir + " failed", e) } } if (!foundLocalDir) { - logError(s"Failed $MAX_DIR_CREATION_ATTEMPTS attempts to create tachyon dir in $rootDir") + logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + " attempts to create tachyon dir in " + + rootDir) System.exit(ExecutorExitCode.TACHYON_STORE_FAILED_TO_CREATE_DIR) } - logInfo(s"Created tachyon directory at $tachyonDir") + logInfo("Created tachyon directory at " + tachyonDir) tachyonDir } } @@ -144,7 +145,7 @@ private[spark] class TachyonBlockManager( } } catch { case e: Exception => - logError(s"Exception while deleting tachyon spark dir: $tachyonDir", e) + logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) } } client.close() diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala b/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala index b86abbda1d3e7..65fa81704c365 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala @@ -24,5 +24,7 @@ import tachyon.client.TachyonFile * a length. */ private[spark] class TachyonFileSegment(val file: TachyonFile, val offset: Long, val length: Long) { - override def toString = "(name=%s, offset=%d, length=%d)".format(file.getPath(), offset, length) + override def toString: String = { + "(name=%s, offset=%d, length=%d)".format(file.getPath(), offset, length) + } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 0c24ad2760e08..adfa6bbada256 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -60,7 +60,7 @@ private[spark] class SparkUI private ( } initialize() - def getAppName = appName + def getAppName: String = appName /** Set the app name for this UI. */ def setAppName(name: String) { diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index b5022fe853c49..f07864141a21c 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -149,9 +149,11 @@ private[spark] object UIUtils extends Logging { } } - def prependBaseUri(basePath: String = "", resource: String = "") = uiRoot + basePath + resource + def prependBaseUri(basePath: String = "", resource: String = ""): String = { + uiRoot + basePath + resource + } - def commonHeaderNodes = { + def commonHeaderNodes: Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index fc1844600f1cb..5fbcd6bb8ad94 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui +import java.util.concurrent.Semaphore + import scala.util.Random import org.apache.spark.{SparkConf, SparkContext} @@ -51,7 +53,7 @@ private[spark] object UIWorkloadGenerator { val nJobSet = args(2).toInt val sc = new SparkContext(conf) - def setProperties(s: String) = { + def setProperties(s: String): Unit = { if(schedulingMode == SchedulingMode.FAIR) { sc.setLocalProperty("spark.scheduler.pool", s) } @@ -59,7 +61,7 @@ private[spark] object UIWorkloadGenerator { } val baseData = sc.makeRDD(1 to NUM_PARTITIONS * 10, NUM_PARTITIONS) - def nextFloat() = new Random().nextFloat() + def nextFloat(): Float = new Random().nextFloat() val jobs = Seq[(String, () => Long)]( ("Count", baseData.count), @@ -88,6 +90,8 @@ private[spark] object UIWorkloadGenerator { ("Job with delays", baseData.map(x => Thread.sleep(100)).count) ) + val barrier = new Semaphore(-nJobSet * jobs.size + 1) + (1 to nJobSet).foreach { _ => for ((desc, job) <- jobs) { new Thread { @@ -99,12 +103,17 @@ private[spark] object UIWorkloadGenerator { } catch { case e: Exception => println("Job Failed: " + desc) + } finally { + barrier.release() } } }.start Thread.sleep(INTER_JOB_WAIT_MS) } } + + // Waiting for threads. + barrier.acquire() sc.stop() } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 3afd7ef07d7c9..69053fe44d7e4 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.HashMap import org.apache.spark.ExceptionFailure import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ -import org.apache.spark.storage.StorageStatusListener +import org.apache.spark.storage.{StorageStatus, StorageStatusListener} import org.apache.spark.ui.{SparkUI, SparkUITab} private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { @@ -55,19 +55,19 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp val executorToShuffleWrite = HashMap[String, Long]() val executorToLogUrls = HashMap[String, Map[String, String]]() - def storageStatusList = storageStatusListener.storageStatusList + def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList - override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded) = synchronized { + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = synchronized { val eid = executorAdded.executorId executorToLogUrls(eid) = executorAdded.executorInfo.logUrlMap } - override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { val eid = taskStart.taskInfo.executorId executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 0) + 1 } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val info = taskEnd.taskInfo if (info != null) { val eid = info.executorId diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 937d95a934b59..625596885faa1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -44,6 +44,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // These type aliases are public because they're used in the types of public fields: type JobId = Int + type JobGroupId = String type StageId = Int type StageAttemptId = Int type PoolName = String @@ -54,6 +55,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val completedJobs = ListBuffer[JobUIData]() val failedJobs = ListBuffer[JobUIData]() val jobIdToData = new HashMap[JobId, JobUIData] + val jobGroupToJobIds = new HashMap[JobGroupId, HashSet[JobId]] // Stages: val pendingStages = new HashMap[StageId, StageInfo] @@ -73,7 +75,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // Misc: val executorIdToBlockManagerId = HashMap[ExecutorId, BlockManagerId]() - def blockManagerIds = executorIdToBlockManagerId.values.toSeq + def blockManagerIds: Seq[BlockManagerId] = executorIdToBlockManagerId.values.toSeq var schedulingMode: Option[SchedulingMode] = None @@ -119,7 +121,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { Map( "jobIdToData" -> jobIdToData.size, "stageIdToData" -> stageIdToData.size, - "stageIdToStageInfo" -> stageIdToInfo.size + "stageIdToStageInfo" -> stageIdToInfo.size, + "jobGroupToJobIds" -> jobGroupToJobIds.values.map(_.size).sum, + // Since jobGroupToJobIds is map of sets, check that we don't leak keys with empty values: + "jobGroupToJobIds keySet" -> jobGroupToJobIds.keys.size ) } @@ -140,13 +145,25 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { if (jobs.size > retainedJobs) { val toRemove = math.max(retainedJobs / 10, 1) jobs.take(toRemove).foreach { job => - jobIdToData.remove(job.jobId) + // Remove the job's UI data, if it exists + jobIdToData.remove(job.jobId).foreach { removedJob => + // A null jobGroupId is used for jobs that are run without a job group + val jobGroupId = removedJob.jobGroup.orNull + // Remove the job group -> job mapping entry, if it exists + jobGroupToJobIds.get(jobGroupId).foreach { jobsInGroup => + jobsInGroup.remove(job.jobId) + // If this was the last job in this job group, remove the map entry for the job group + if (jobsInGroup.isEmpty) { + jobGroupToJobIds.remove(jobGroupId) + } + } + } } jobs.trimStart(toRemove) } } - override def onJobStart(jobStart: SparkListenerJobStart) = synchronized { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { val jobGroup = for ( props <- Option(jobStart.properties); group <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) @@ -158,6 +175,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageIds = jobStart.stageIds, jobGroup = jobGroup, status = JobExecutionStatus.RUNNING) + // A null jobGroupId is used for jobs that are run without a job group + jobGroupToJobIds.getOrElseUpdate(jobGroup.orNull, new HashSet[JobId]).add(jobStart.jobId) jobStart.stageInfos.foreach(x => pendingStages(x.stageId) = x) // Compute (a potential underestimate of) the number of tasks that will be run by this job. // This may be an underestimate because the job start event references all of the result @@ -182,7 +201,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } } - override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { val jobData = activeJobs.remove(jobEnd.jobId).getOrElse { logWarning(s"Job completed for unknown job ${jobEnd.jobId}") new JobUIData(jobId = jobEnd.jobId) @@ -219,7 +238,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } } - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { val stage = stageCompleted.stageInfo stageIdToInfo(stage.stageId) = stage val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), { @@ -260,7 +279,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } /** For FIFO, all stages are contained by "default" pool but "default" pool here is meaningless */ - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { val stage = stageSubmitted.stageInfo activeStages(stage.stageId) = stage pendingStages.remove(stage.stageId) @@ -288,7 +307,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } } - override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { val taskInfo = taskStart.taskInfo if (taskInfo != null) { val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), { @@ -312,7 +331,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // stageToTaskInfos already has the updated status. } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val info = taskEnd.taskInfo // If stage attempt id is -1, it means the DAGScheduler had no idea which attempt this task // completion event is for. Let's just drop it here. This means we might have some speculation diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index b2bbfdee56946..7ffcf291b5cc6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -24,7 +24,7 @@ import org.apache.spark.ui.{SparkUI, SparkUITab} private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { val sc = parent.sc val killEnabled = parent.killEnabled - def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) + def isFairScheduler: Boolean = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) val listener = parent.jobProgressListener attachPage(new AllJobsPage(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 110f8780a9a12..797c9404bc449 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs import java.util.Date import javax.servlet.http.HttpServletRequest -import scala.xml.{Node, Unparsed} +import scala.xml.{Elem, Node, Unparsed} import org.apache.commons.lang3.StringEscapeUtils @@ -170,7 +170,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") - def accumulableRow(acc: AccumulableInfo) = {acc.name}{acc.value} + def accumulableRow(acc: AccumulableInfo): Elem = + {acc.name}{acc.value} val accumulableTable = UIUtils.listingTable(accumulableHeaders, accumulableRow, accumulables.values.toSeq) @@ -268,11 +269,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(serializationTimes) val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => - if (info.gettingResultTime > 0) { - (info.finishTime - info.gettingResultTime).toDouble - } else { - 0.0 - } + getGettingResultTime(info).toDouble } val gettingResultQuantiles = @@ -293,10 +290,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val schedulerDelayQuantiles = schedulerDelayTitle +: getFormattedTimeQuantiles(schedulerDelays) - def getFormattedSizeQuantiles(data: Seq[Double]) = + def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) - def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) = { + def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) + : Seq[Elem] = { val recordDist = getDistributionQuantiles(records).iterator getDistributionQuantiles(data).map(d => {s"${Utils.bytesToString(d.toLong)} / ${recordDist.next().toLong}"} @@ -462,7 +460,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) - val gettingResultTime = info.gettingResultTime + val gettingResultTime = getGettingResultTime(info) val maybeAccumulators = info.accumulables val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"} @@ -625,6 +623,19 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {errorSummary}{details} } + private def getGettingResultTime(info: TaskInfo): Long = { + if (info.gettingResultTime > 0) { + if (info.finishTime > 0) { + info.finishTime - info.gettingResultTime + } else { + // The task is still fetching the result. + System.currentTimeMillis - info.gettingResultTime + } + } else { + 0L + } + } + private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { val totalExecutionTime = if (info.gettingResult) { @@ -636,6 +647,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } val executorOverhead = (metrics.executorDeserializeTime + metrics.resultSerializationTime) - math.max(0, totalExecutionTime - metrics.executorRunTime - executorOverhead) + math.max( + 0, + totalExecutionTime - metrics.executorRunTime - executorOverhead - getGettingResultTime(info)) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 937261de00e3a..1bd2d87e00796 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -32,10 +32,10 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages" attachPage(new StagePage(this)) attachPage(new PoolPage(this)) - def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) + def isFairScheduler: Boolean = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) - def handleKillRequest(request: HttpServletRequest) = { - if ((killEnabled) && (parent.securityManager.checkModifyPermissions(request.getRemoteUser))) { + def handleKillRequest(request: HttpServletRequest): Unit = { + if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index dbf1ceeda1878..711a3697bda15 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -94,11 +94,11 @@ private[jobs] object UIData { var taskData = new HashMap[Long, TaskUIData] var executorSummary = new HashMap[String, ExecutorSummary] - def hasInput = inputBytes > 0 - def hasOutput = outputBytes > 0 - def hasShuffleRead = shuffleReadTotalBytes > 0 - def hasShuffleWrite = shuffleWriteBytes > 0 - def hasBytesSpilled = memoryBytesSpilled > 0 && diskBytesSpilled > 0 + def hasInput: Boolean = inputBytes > 0 + def hasOutput: Boolean = outputBytes > 0 + def hasShuffleRead: Boolean = shuffleReadTotalBytes > 0 + def hasShuffleWrite: Boolean = shuffleWriteBytes > 0 + def hasBytesSpilled: Boolean = memoryBytesSpilled > 0 && diskBytesSpilled > 0 } /** diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index a81291d505583..045bd784990d1 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -40,10 +40,10 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag class StorageListener(storageStatusListener: StorageStatusListener) extends SparkListener { private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing - def storageStatusList = storageStatusListener.storageStatusList + def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList /** Filter RDD info to include only those with cached partitions */ - def rddInfoList = _rddInfoMap.values.filter(_.numCachedPartitions > 0).toSeq + def rddInfoList: Seq[RDDInfo] = _rddInfoMap.values.filter(_.numCachedPartitions > 0).toSeq /** Update the storage info of the RDDs whose blocks are among the given updated blocks */ private def updateRDDInfo(updatedBlocks: Seq[(BlockId, BlockStatus)]): Unit = { @@ -56,19 +56,19 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Spar * Assumes the storage status list is fully up-to-date. This implies the corresponding * StorageStatusSparkListener must process the SparkListenerTaskEnd event before this listener. */ - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { val metrics = taskEnd.taskMetrics if (metrics != null && metrics.updatedBlocks.isDefined) { updateRDDInfo(metrics.updatedBlocks.get) } } - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { val rddInfos = stageSubmitted.stageInfo.rddInfos rddInfos.foreach { info => _rddInfoMap.getOrElseUpdate(info.id, info) } } - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { // Remove all partitions that are no longer cached in current completed stage val completedRddIds = stageCompleted.stageInfo.rddInfos.map(r => r.id).toSet _rddInfoMap.retain { case (id, info) => @@ -76,7 +76,7 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Spar } } - override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) = synchronized { + override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized { _rddInfoMap.remove(unpersistRDD.rddId) } } diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index 390310243ee0a..9044aaeef2d48 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -27,8 +27,8 @@ abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterat // scalastyle:on private[this] var completed = false - def next() = sub.next() - def hasNext = { + def next(): A = sub.next() + def hasNext: Boolean = { val r = sub.hasNext if (!r && !completed) { completed = true @@ -37,13 +37,13 @@ abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterat r } - def completion() + def completion(): Unit } private[spark] object CompletionIterator { - def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A,I] = { + def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A, I] = { new CompletionIterator[A,I](sub) { - def completion() = completionFunction + def completion(): Unit = completionFunction } } } diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala index a465298c8c5ab..9aea8efa38c7a 100644 --- a/core/src/main/scala/org/apache/spark/util/Distribution.scala +++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala @@ -57,7 +57,7 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va out.println } - def statCounter = StatCounter(data.slice(startIdx, endIdx)) + def statCounter: StatCounter = StatCounter(data.slice(startIdx, endIdx)) /** * print a summary of this distribution to the given PrintStream. diff --git a/core/src/main/scala/org/apache/spark/util/ManualClock.scala b/core/src/main/scala/org/apache/spark/util/ManualClock.scala index cf89c1782fd67..1718554061985 100644 --- a/core/src/main/scala/org/apache/spark/util/ManualClock.scala +++ b/core/src/main/scala/org/apache/spark/util/ManualClock.scala @@ -39,31 +39,27 @@ private[spark] class ManualClock(private var time: Long) extends Clock { /** * @param timeToSet new time (in milliseconds) that the clock should represent */ - def setTime(timeToSet: Long) = - synchronized { - time = timeToSet - notifyAll() - } + def setTime(timeToSet: Long): Unit = synchronized { + time = timeToSet + notifyAll() + } /** * @param timeToAdd time (in milliseconds) to add to the clock's time */ - def advance(timeToAdd: Long) = - synchronized { - time += timeToAdd - notifyAll() - } + def advance(timeToAdd: Long): Unit = synchronized { + time += timeToAdd + notifyAll() + } /** * @param targetTime block until the clock time is set or advanced to at least this time * @return current time reported by the clock when waiting finishes */ - def waitTillTime(targetTime: Long): Long = - synchronized { - while (time < targetTime) { - wait(100) - } - getTimeMillis() + def waitTillTime(targetTime: Long): Long = synchronized { + while (time < targetTime) { + wait(100) } - + getTimeMillis() + } } diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index ac40f19ed6799..375ed430bde45 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -67,14 +67,15 @@ private[spark] object MetadataCleanerType extends Enumeration { type MetadataCleanerType = Value - def systemProperty(which: MetadataCleanerType.MetadataCleanerType) = - "spark.cleaner.ttl." + which.toString + def systemProperty(which: MetadataCleanerType.MetadataCleanerType): String = { + "spark.cleaner.ttl." + which.toString + } } // TODO: This mutates a Conf to set properties right now, which is kind of ugly when used in the // initialization of StreamingContext. It's okay for users trying to configure stuff themselves. private[spark] object MetadataCleaner { - def getDelaySeconds(conf: SparkConf) = { + def getDelaySeconds(conf: SparkConf): Int = { conf.getInt("spark.cleaner.ttl", -1) } diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala index 74fa77b68de0b..dad888548ed10 100644 --- a/core/src/main/scala/org/apache/spark/util/MutablePair.scala +++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala @@ -43,7 +43,7 @@ case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/* , AnyRef this } - override def toString = "(" + _1 + "," + _2 + ")" + override def toString: String = "(" + _1 + "," + _2 + ")" override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]] } diff --git a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala index 6d8d9e8da3678..73d126ff6254e 100644 --- a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala @@ -22,7 +22,7 @@ package org.apache.spark.util */ private[spark] class ParentClassLoader(parent: ClassLoader) extends ClassLoader(parent) { - override def findClass(name: String) = { + override def findClass(name: String): Class[_] = { super.findClass(name) } diff --git a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala index 770ff9d5ad6ae..a06b6f84ef11b 100644 --- a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala @@ -27,7 +27,7 @@ import java.nio.channels.Channels */ private[spark] class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable { - def value = buffer + def value: ByteBuffer = buffer private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { val length = in.readInt() diff --git a/core/src/main/scala/org/apache/spark/util/StatCounter.scala b/core/src/main/scala/org/apache/spark/util/StatCounter.scala index d80eed455c427..8586da1996cf3 100644 --- a/core/src/main/scala/org/apache/spark/util/StatCounter.scala +++ b/core/src/main/scala/org/apache/spark/util/StatCounter.scala @@ -141,8 +141,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { object StatCounter { /** Build a StatCounter from a list of values. */ - def apply(values: TraversableOnce[Double]) = new StatCounter(values) + def apply(values: TraversableOnce[Double]): StatCounter = new StatCounter(values) /** Build a StatCounter from a list of values passed as variable-length arguments. */ - def apply(values: Double*) = new StatCounter(values) + def apply(values: Double*): StatCounter = new StatCounter(values) } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala index f5be5856c2109..310c0c109416c 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -82,7 +82,7 @@ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boo this } - override def update(key: A, value: B) = this += ((key, value)) + override def update(key: A, value: B): Unit = this += ((key, value)) override def apply(key: A): B = internalMap.apply(key) @@ -92,14 +92,14 @@ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boo override def size: Int = internalMap.size - override def foreach[U](f: ((A, B)) => U) = nonNullReferenceMap.foreach(f) + override def foreach[U](f: ((A, B)) => U): Unit = nonNullReferenceMap.foreach(f) def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value) def toMap: Map[A, B] = iterator.toMap /** Remove old key-value pairs with timestamps earlier than `threshTime`. */ - def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime) + def clearOldValues(threshTime: Long): Unit = internalMap.clearOldValues(threshTime) /** Remove entries with values that are no longer strongly reachable. */ def clearNullValues() { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 91d833295e376..0b5a914e7dbbf 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -42,8 +42,6 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ - -import tachyon.TachyonURI import tachyon.client.{TachyonFS, TachyonFile} import org.apache.spark._ @@ -87,7 +85,7 @@ private[spark] object Utils extends Logging { def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { val bis = new ByteArrayInputStream(bytes) val ois = new ObjectInputStream(bis) { - override def resolveClass(desc: ObjectStreamClass) = + override def resolveClass(desc: ObjectStreamClass): Class[_] = Class.forName(desc.getName, false, loader) } ois.readObject.asInstanceOf[T] @@ -108,11 +106,10 @@ private[spark] object Utils extends Logging { /** Serialize via nested stream using specific serializer */ def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)( - f: SerializationStream => Unit) = { + f: SerializationStream => Unit): Unit = { val osWrapper = ser.serializeStream(new OutputStream { - def write(b: Int) = os.write(b) - - override def write(b: Array[Byte], off: Int, len: Int) = os.write(b, off, len) + override def write(b: Int): Unit = os.write(b) + override def write(b: Array[Byte], off: Int, len: Int): Unit = os.write(b, off, len) }) try { f(osWrapper) @@ -123,10 +120,9 @@ private[spark] object Utils extends Logging { /** Deserialize via nested stream using specific serializer */ def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)( - f: DeserializationStream => Unit) = { + f: DeserializationStream => Unit): Unit = { val isWrapper = ser.deserializeStream(new InputStream { - def read(): Int = is.read() - + override def read(): Int = is.read() override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len) }) try { @@ -139,7 +135,7 @@ private[spark] object Utils extends Logging { /** * Get the ClassLoader which loaded Spark. */ - def getSparkClassLoader = getClass.getClassLoader + def getSparkClassLoader: ClassLoader = getClass.getClassLoader /** * Get the Context ClassLoader on this thread or, if not present, the ClassLoader that @@ -148,7 +144,7 @@ private[spark] object Utils extends Logging { * This should be used whenever passing a ClassLoader to Class.ForName or finding the currently * active loader when setting up ClassLoader delegation chains. */ - def getContextOrSparkClassLoader = + def getContextOrSparkClassLoader: ClassLoader = Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader) /** Determines whether the provided class is loadable in the current thread. */ @@ -157,12 +153,14 @@ private[spark] object Utils extends Logging { } /** Preferred alternative to Class.forName(className) */ - def classForName(className: String) = Class.forName(className, true, getContextOrSparkClassLoader) + def classForName(className: String): Class[_] = { + Class.forName(className, true, getContextOrSparkClassLoader) + } /** * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]] */ - def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = { + def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput): Unit = { if (bb.hasArray) { out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) } else { @@ -972,7 +970,7 @@ private[spark] object Utils extends Logging { * Delete a file or directory and its contents recursively. */ def deleteRecursively(dir: TachyonFile, client: TachyonFS) { - if (!client.delete(new TachyonURI(dir.getPath()), true)) { + if (!client.delete(dir.getPath(), true)) { throw new IOException("Failed to delete the tachyon dir: " + dir) } } @@ -1559,7 +1557,7 @@ private[spark] object Utils extends Logging { /** Return the class name of the given object, removing all dollar signs */ - def getFormattedClassName(obj: AnyRef) = { + def getFormattedClassName(obj: AnyRef): String = { obj.getClass.getSimpleName.replace("$", "") } @@ -1572,7 +1570,7 @@ private[spark] object Utils extends Logging { } /** Return an empty JSON object */ - def emptyJson = JObject(List[JField]()) + def emptyJson: JsonAST.JObject = JObject(List[JField]()) /** * Return a Hadoop FileSystem with the scheme encoded in the given path. @@ -1620,7 +1618,7 @@ private[spark] object Utils extends Logging { /** * Indicates whether Spark is currently running unit tests. */ - def isTesting = { + def isTesting: Boolean = { sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing") } @@ -1878,6 +1876,10 @@ private[spark] object Utils extends Logging { startService: Int => (T, Int), conf: SparkConf, serviceName: String = ""): (T, Int) = { + + require(startPort == 0 || (1024 <= startPort && startPort < 65536), + "startPort should be between 1024 and 65535 (inclusive), or 0 for a random free port.") + val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'" val maxRetries = portMaxRetries(conf) for (offset <- 0 to maxRetries) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index af1f64649f354..f79e8e0491ea1 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -156,10 +156,10 @@ class BitSet(numBits: Int) extends Serializable { /** * Get an iterator over the set bits. */ - def iterator = new Iterator[Int] { + def iterator: Iterator[Int] = new Iterator[Int] { var ind = nextSetBit(0) override def hasNext: Boolean = ind >= 0 - override def next() = { + override def next(): Int = { val tmp = ind ind = nextSetBit(ind + 1) tmp diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 8a0f5a602de12..9ff4744593d4d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -159,7 +159,7 @@ class ExternalAppendOnlyMap[K, V, C]( val batchSizes = new ArrayBuffer[Long] // Flush the disk writer's contents to disk, and update relevant variables - def flush() = { + def flush(): Unit = { val w = writer writer = null w.commitAndClose() @@ -355,7 +355,7 @@ class ExternalAppendOnlyMap[K, V, C]( val pairs: ArrayBuffer[(K, C)]) extends Comparable[StreamBuffer] { - def isEmpty = pairs.length == 0 + def isEmpty: Boolean = pairs.length == 0 // Invalid if there are no more pairs in this stream def minKeyHash: Int = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index d69f2d9048055..b962c101c91da 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -283,7 +283,7 @@ private[spark] class ExternalSorter[K, V, C]( // Flush the disk writer's contents to disk, and update relevant variables. // The writer is closed at the end of this process, and cannot be reused. - def flush() = { + def flush(): Unit = { val w = writer writer = null w.commitAndClose() @@ -352,6 +352,7 @@ private[spark] class ExternalSorter[K, V, C]( // Create our file writers if we haven't done so yet if (partitionWriters == null) { curWriteMetrics = new ShuffleWriteMetrics() + val openStartTime = System.nanoTime partitionWriters = Array.fill(numPartitions) { // Because these files may be read during shuffle, their compression must be controlled by // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use @@ -359,6 +360,10 @@ private[spark] class ExternalSorter[K, V, C]( val (blockId, file) = diskBlockManager.createTempShuffleBlock() blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open() } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. + curWriteMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) } // No need to sort stuff, just write each element out diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala index b8de4ff9aa494..efc2482c74ddf 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -53,6 +53,15 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size + /** Tests whether this map contains a binding for a key. */ + def contains(k: K): Boolean = { + if (k == null) { + haveNullValue + } else { + _keySet.getPos(k) != OpenHashSet.INVALID_POS + } + } + /** Get the value for a given key */ def apply(k: K): V = { if (k == null) { @@ -109,7 +118,7 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( } } - override def iterator = new Iterator[(K, V)] { + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { var pos = -1 var nextPair: (K, V) = computeNextPair() @@ -132,9 +141,9 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( } } - def hasNext = nextPair != null + def hasNext: Boolean = nextPair != null - def next() = { + def next(): (K, V) = { val pair = nextPair nextPair = computeNextPair() pair diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 4e363b74f4bef..1501111a06655 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -85,7 +85,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( protected var _bitset = new BitSet(_capacity) - def getBitSet = _bitset + def getBitSet: BitSet = _bitset // Init of the array in constructor (instead of in declaration) to work around a Scala compiler // specialization bug that would generate two arrays (one for Object and one for specialized T). @@ -122,7 +122,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( */ def addWithoutResize(k: T): Int = { var pos = hashcode(hasher.hash(k)) & _mask - var i = 1 + var delta = 1 while (true) { if (!_bitset.get(pos)) { // This is a new key. @@ -134,14 +134,12 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( // Found an existing key. return pos } else { - val delta = i + // quadratic probing with values increase by 1, 2, 3, ... pos = (pos + delta) & _mask - i += 1 + delta += 1 } } - // Never reached here - assert(INVALID_POS != INVALID_POS) - INVALID_POS + throw new RuntimeException("Should never reach here.") } /** @@ -163,27 +161,25 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( */ def getPos(k: T): Int = { var pos = hashcode(hasher.hash(k)) & _mask - var i = 1 - val maxProbe = _data.size - while (i < maxProbe) { + var delta = 1 + while (true) { if (!_bitset.get(pos)) { return INVALID_POS } else if (k == _data(pos)) { return pos } else { - val delta = i + // quadratic probing with values increase by 1, 2, 3, ... pos = (pos + delta) & _mask - i += 1 + delta += 1 } } - // Never reached here - INVALID_POS + throw new RuntimeException("Should never reach here.") } /** Return the value at the specified position. */ def getValue(pos: Int): T = _data(pos) - def iterator = new Iterator[T] { + def iterator: Iterator[T] = new Iterator[T] { var pos = nextPos(0) override def hasNext: Boolean = pos != INVALID_POS override def next(): T = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala index 2e1ef06cbc4e1..b4ec4ea521253 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala @@ -46,7 +46,12 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, private var _oldValues: Array[V] = null - override def size = _keySet.size + override def size: Int = _keySet.size + + /** Tests whether this map contains a binding for a key. */ + def contains(k: K): Boolean = { + _keySet.getPos(k) != OpenHashSet.INVALID_POS + } /** Get the value for a given key */ def apply(k: K): V = { @@ -87,7 +92,7 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - override def iterator = new Iterator[(K, V)] { + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { var pos = 0 var nextPair: (K, V) = computeNextPair() @@ -103,9 +108,9 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, } } - def hasNext = nextPair != null + def hasNext: Boolean = nextPair != null - def next() = { + def next(): (K, V) = { val pair = nextPair nextPair = computeNextPair() pair diff --git a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala index c5268c0fae0ef..bdbca00a00622 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala @@ -32,7 +32,7 @@ private[spark] object Utils { */ def takeOrdered[T](input: Iterator[T], num: Int)(implicit ord: Ordering[T]): Iterator[T] = { val ordering = new GuavaOrdering[T] { - override def compare(l: T, r: T) = ord.compare(l, r) + override def compare(l: T, r: T): Int = ord.compare(l, r) } collectionAsScalaIterable(ordering.leastOf(asJavaIterator(input), num)).iterator } diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala index 1d5467060623c..14b6ba4af489a 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala @@ -121,7 +121,7 @@ private[spark] object FileAppender extends Logging { val rollingSizeBytes = conf.get(SIZE_PROPERTY, STRATEGY_DEFAULT) val rollingInterval = conf.get(INTERVAL_PROPERTY, INTERVAL_DEFAULT) - def createTimeBasedAppender() = { + def createTimeBasedAppender(): FileAppender = { val validatedParams: Option[(Long, String)] = rollingInterval match { case "daily" => logInfo(s"Rolling executor logs enabled for $file with daily rolling") @@ -149,7 +149,7 @@ private[spark] object FileAppender extends Logging { } } - def createSizeBasedAppender() = { + def createSizeBasedAppender(): FileAppender = { rollingSizeBytes match { case IntParam(bytes) => logInfo(s"Rolling executor logs enabled for $file with rolling every $bytes bytes") diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 76e7a2760bcd1..786b97ad7b9ec 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -105,7 +105,7 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals private val rng: Random = new XORShiftRandom - override def setSeed(seed: Long) = rng.setSeed(seed) + override def setSeed(seed: Long): Unit = rng.setSeed(seed) override def sample(items: Iterator[T]): Iterator[T] = { if (ub - lb <= 0.0) { @@ -131,7 +131,7 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals def cloneComplement(): BernoulliCellSampler[T] = new BernoulliCellSampler[T](lb, ub, !complement) - override def clone = new BernoulliCellSampler[T](lb, ub, complement) + override def clone: BernoulliCellSampler[T] = new BernoulliCellSampler[T](lb, ub, complement) } @@ -153,7 +153,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T private val rng: Random = RandomSampler.newDefaultRNG - override def setSeed(seed: Long) = rng.setSeed(seed) + override def setSeed(seed: Long): Unit = rng.setSeed(seed) override def sample(items: Iterator[T]): Iterator[T] = { if (fraction <= 0.0) { @@ -167,7 +167,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T } } - override def clone = new BernoulliSampler[T](fraction) + override def clone: BernoulliSampler[T] = new BernoulliSampler[T](fraction) } @@ -209,7 +209,7 @@ class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] } } - override def clone = new PoissonSampler[T](fraction) + override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction) } @@ -228,15 +228,18 @@ class GapSamplingIterator[T: ClassTag]( val arrayClass = Array.empty[T].iterator.getClass val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass data.getClass match { - case `arrayClass` => ((n: Int) => { data = data.drop(n) }) - case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) }) - case _ => ((n: Int) => { + case `arrayClass` => + (n: Int) => { data = data.drop(n) } + case `arrayBufferClass` => + (n: Int) => { data = data.drop(n) } + case _ => + (n: Int) => { var j = 0 while (j < n && data.hasNext) { data.next() j += 1 } - }) + } } } @@ -244,21 +247,21 @@ class GapSamplingIterator[T: ClassTag]( override def next(): T = { val r = data.next() - advance + advance() r } private val lnq = math.log1p(-f) /** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */ - private def advance: Unit = { + private def advance(): Unit = { val u = math.max(rng.nextDouble(), epsilon) val k = (math.log(u) / lnq).toInt iterDrop(k) } /** advance to first sample as part of object construction. */ - advance + advance() // Attempting to invoke this closer to the top with other object initialization // was causing it to break in strange ways, so I'm invoking it last, which seems to // work reliably. @@ -279,15 +282,18 @@ class GapSamplingReplacementIterator[T: ClassTag]( val arrayClass = Array.empty[T].iterator.getClass val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass data.getClass match { - case `arrayClass` => ((n: Int) => { data = data.drop(n) }) - case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) }) - case _ => ((n: Int) => { + case `arrayClass` => + (n: Int) => { data = data.drop(n) } + case `arrayBufferClass` => + (n: Int) => { data = data.drop(n) } + case _ => + (n: Int) => { var j = 0 while (j < n && data.hasNext) { data.next() j += 1 } - }) + } } } @@ -300,7 +306,7 @@ class GapSamplingReplacementIterator[T: ClassTag]( override def next(): T = { val r = v rep -= 1 - if (rep <= 0) advance + if (rep <= 0) advance() r } @@ -309,7 +315,7 @@ class GapSamplingReplacementIterator[T: ClassTag]( * Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is * q is the probabililty of Poisson(0; f) */ - private def advance: Unit = { + private def advance(): Unit = { val u = math.max(rng.nextDouble(), epsilon) val k = (math.log(u) / (-f)).toInt iterDrop(k) @@ -343,7 +349,7 @@ class GapSamplingReplacementIterator[T: ClassTag]( } /** advance to first sample as part of object construction. */ - advance + advance() // Attempting to invoke this closer to the top with other object initialization // was causing it to break in strange ways, so I'm invoking it last, which seems to // work reliably. diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala index 2ae308dacf1ae..9e29bf9d61f17 100644 --- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala @@ -311,7 +311,7 @@ private[random] class AcceptanceResult(var numItems: Long = 0L, var numAccepted: var acceptBound: Double = Double.NaN // upper bound for accepting item instantly var waitListBound: Double = Double.NaN // upper bound for adding item to waitlist - def areBoundsEmpty = acceptBound.isNaN || waitListBound.isNaN + def areBoundsEmpty: Boolean = acceptBound.isNaN || waitListBound.isNaN def merge(other: Option[AcceptanceResult]): Unit = { if (other.isDefined) { diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index 467b890fb4bb9..c4a7b4441c85c 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -83,7 +83,7 @@ private[spark] object XORShiftRandom { * @return Map of execution times for {@link java.util.Random java.util.Random} * and XORShift */ - def benchmark(numIters: Int) = { + def benchmark(numIters: Int): Map[String, Long] = { val seed = 1L val million = 1e6.toInt diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 4cd0f97368ca3..97079382c716f 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -235,6 +235,12 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramBuckets === expectedHistogramBuckets) } + test("WorksWithDoubleValuesAtMinMax") { + val rdd = sc.parallelize(Seq(1, 1, 1, 2, 3, 3)) + assert(Array(3, 0, 1, 2) === rdd.map(_.toDouble).histogram(4)._2) + assert(Array(3, 1, 2) === rdd.map(_.toDouble).histogram(3)._2) + } + test("WorksWithoutBucketsWithMoreRequestedThanElements") { // Verify the basic case of one bucket and all elements in that bucket works val rdd = sc.parallelize(Seq(1, 2)) @@ -248,7 +254,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { } test("WorksWithoutBucketsForLargerDatasets") { - // Verify the case of slighly larger datasets + // Verify the case of slightly larger datasets val rdd = sc.parallelize(6 to 99) val (histogramBuckets, histogramResults) = rdd.histogram(8) val expectedHistogramResults = @@ -259,17 +265,27 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramBuckets === expectedHistogramBuckets) } - test("WorksWithoutBucketsWithIrrationalBucketEdges") { - // Verify the case of buckets with irrational edges. See #SPARK-2862. + test("WorksWithoutBucketsWithNonIntegralBucketEdges") { + // Verify the case of buckets with nonintegral edges. See #SPARK-2862. val rdd = sc.parallelize(6 to 99) val (histogramBuckets, histogramResults) = rdd.histogram(9) + // Buckets are 6.0, 16.333333333333336, 26.666666666666668, 37.0, 47.333333333333336 ... val expectedHistogramResults = - Array(11, 10, 11, 10, 10, 11, 10, 10, 11) + Array(11, 10, 10, 11, 10, 10, 11, 10, 11) assert(histogramResults === expectedHistogramResults) assert(histogramBuckets(0) === 6.0) assert(histogramBuckets(9) === 99.0) } + test("WorksWithHugeRange") { + val rdd = sc.parallelize(Array(0, 1.0e24, 1.0e30)) + val histogramResults = rdd.histogram(1000000)._2 + assert(histogramResults(0) === 1) + assert(histogramResults(1) === 1) + assert(histogramResults.last === 1) + assert((2 to histogramResults.length - 2).forall(i => histogramResults(i) == 0)) + } + // Test the failure mode with an invalid RDD test("ThrowsExceptionOnInvalidRDDs") { // infinity diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 3fdbe99b5d02b..ecd1cba5b5abe 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -170,8 +170,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach assert(master.getLocations("a3").size === 0, "master was told about a3") // Drop a1 and a2 from memory; this should be reported back to the master - store.dropFromMemory("a1", null) - store.dropFromMemory("a2", null) + store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer]) + store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer]) assert(store.getSingle("a1") === None, "a1 not removed from store") assert(store.getSingle("a2") === None, "a2 not removed from store") assert(master.getLocations("a1").size === 0, "master did not remove a1") @@ -413,8 +413,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach t2.join() t3.join() - store.dropFromMemory("a1", null) - store.dropFromMemory("a2", null) + store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer]) + store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer]) store.waitForAsyncReregister() } } @@ -1223,4 +1223,30 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) } + + test("lazily create a big ByteBuffer to avoid OOM if it cannot be put into MemoryStore") { + store = makeBlockManager(12000) + val memoryStore = store.memoryStore + val blockId = BlockId("rdd_3_10") + val result = memoryStore.putBytes(blockId, 13000, () => { + fail("A big ByteBuffer that cannot be put into MemoryStore should not be created") + }) + assert(result.size === 13000) + assert(result.data === null) + assert(result.droppedBlocks === Nil) + } + + test("put a small ByteBuffer to MemoryStore") { + store = makeBlockManager(12000) + val memoryStore = store.memoryStore + val blockId = BlockId("rdd_3_10") + var bytes: ByteBuffer = null + val result = memoryStore.putBytes(blockId, 10000, () => { + bytes = ByteBuffer.allocate(10000) + bytes + }) + assert(result.size === 10000) + assert(result.data === Right(bytes)) + assert(result.droppedBlocks === Nil) + } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 730a4b54f5aa1..c0c28cb60e21d 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.jobs +import java.util.Properties + import org.scalatest.FunSuite import org.scalatest.Matchers @@ -44,11 +46,19 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc SparkListenerStageCompleted(stageInfo) } - private def createJobStartEvent(jobId: Int, stageIds: Seq[Int]) = { + private def createJobStartEvent( + jobId: Int, + stageIds: Seq[Int], + jobGroup: Option[String] = None): SparkListenerJobStart = { val stageInfos = stageIds.map { stageId => new StageInfo(stageId, 0, stageId.toString, 0, null, "") } - SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos) + val properties: Option[Properties] = jobGroup.map { groupId => + val props = new Properties() + props.setProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId) + props + } + SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos, properties.orNull) } private def createJobEndEvent(jobId: Int, failed: Boolean = false) = { @@ -110,6 +120,23 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc listener.stageIdToActiveJobIds.size should be (0) } + test("test clearing of jobGroupToJobIds") { + val conf = new SparkConf() + conf.set("spark.ui.retainedJobs", 5.toString) + val listener = new JobProgressListener(conf) + + // Run 50 jobs, each with one stage + for (jobId <- 0 to 50) { + listener.onJobStart(createJobStartEvent(jobId, Seq(0), jobGroup = Some(jobId.toString))) + listener.onStageSubmitted(createStageStartEvent(0)) + listener.onStageCompleted(createStageEndEvent(0, failed = false)) + listener.onJobEnd(createJobEndEvent(jobId, false)) + } + assertActiveJobsStateIsEmpty(listener) + // This collection won't become empty, but it should be bounded by spark.ui.retainedJobs + listener.jobGroupToJobIds.size should be (5) + } + test("test LRU eviction of jobs") { val conf = new SparkConf() conf.set("spark.ui.retainedStages", 5.toString) diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index 6a70877356409..ef890d2ba60f3 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -176,4 +176,14 @@ class OpenHashMapSuite extends FunSuite with Matchers { assert(map(i.toString) === i.toString) } } + + test("contains") { + val map = new OpenHashMap[String, Int](2) + map("a") = 1 + assert(map.contains("a")) + assert(!map.contains("b")) + assert(!map.contains(null)) + map(null) = 0 + assert(map.contains(null)) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala index 8c7df7d73dcd3..caf378fec8b3e 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -118,4 +118,11 @@ class PrimitiveKeyOpenHashMapSuite extends FunSuite with Matchers { assert(map(i.toLong) === i.toString) } } + + test("contains") { + val map = new PrimitiveKeyOpenHashMap[Int, Int](1) + map(0) = 0 + assert(map.contains(0)) + assert(!map.contains(1)) + } } diff --git a/dev/run-tests b/dev/run-tests index d6935a61c6d29..561d7fc9e7b1f 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -178,6 +178,15 @@ CURRENT_BLOCK=$BLOCK_BUILD fi } +echo "" +echo "=========================================================================" +echo "Detecting binary incompatibilities with MiMa" +echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_MIMA + +./dev/mima + echo "" echo "=========================================================================" echo "Running Spark unit tests" @@ -227,12 +236,3 @@ echo "=========================================================================" CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS ./python/run-tests - -echo "" -echo "=========================================================================" -echo "Detecting binary incompatibilities with MiMa" -echo "=========================================================================" - -CURRENT_BLOCK=$BLOCK_MIMA - -./dev/mima diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh index 1348e0609dda4..8ab6db6925d6e 100644 --- a/dev/run-tests-codes.sh +++ b/dev/run-tests-codes.sh @@ -22,6 +22,6 @@ readonly BLOCK_RAT=11 readonly BLOCK_SCALA_STYLE=12 readonly BLOCK_PYTHON_STYLE=13 readonly BLOCK_BUILD=14 -readonly BLOCK_SPARK_UNIT_TESTS=15 -readonly BLOCK_PYSPARK_UNIT_TESTS=16 -readonly BLOCK_MIMA=17 +readonly BLOCK_MIMA=15 +readonly BLOCK_SPARK_UNIT_TESTS=16 +readonly BLOCK_PYSPARK_UNIT_TESTS=17 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 5f4000e83925c..3a937b637e003 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -199,12 +199,12 @@ done failing_test="Python style tests" elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then failing_test="to build" + elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then + failing_test="MiMa tests" elif [ "$test_result" -eq "$BLOCK_SPARK_UNIT_TESTS" ]; then failing_test="Spark unit tests" elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then failing_test="PySpark unit tests" - elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then - failing_test="MiMa tests" else failing_test="some tests" fi diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index c601d793a2e9a..3f10cb2dc3d2a 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -899,6 +899,8 @@ class VertexRDD[VD] extends RDD[(VertexID, VD)] { // Transform the values without changing the ids (preserves the internal index) def mapValues[VD2](map: VD => VD2): VertexRDD[VD2] def mapValues[VD2](map: (VertexId, VD) => VD2): VertexRDD[VD2] + // Show only vertices unique to this set based on their VertexId's + def minus(other: RDD[(VertexId, VD)]) // Remove vertices from this set that appear in the other set def diff(other: VertexRDD[VD]): VertexRDD[VD] // Join operators that take advantage of the internal indexing to accelerate joins (substantially) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index da6aef7f14c4c..c08c76d226713 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -408,31 +408,31 @@ import org.apache.spark.sql.SQLContext; // Labeled and unlabeled instance types. // Spark SQL can infer schema from Java Beans. public class Document implements Serializable { - private Long id; + private long id; private String text; - public Document(Long id, String text) { + public Document(long id, String text) { this.id = id; this.text = text; } - public Long getId() { return this.id; } - public void setId(Long id) { this.id = id; } + public long getId() { return this.id; } + public void setId(long id) { this.id = id; } public String getText() { return this.text; } public void setText(String text) { this.text = text; } } public class LabeledDocument extends Document implements Serializable { - private Double label; + private double label; - public LabeledDocument(Long id, String text, Double label) { + public LabeledDocument(long id, String text, double label) { super(id, text); this.label = label; } - public Double getLabel() { return this.label; } - public void setLabel(Double label) { this.label = label; } + public double getLabel() { return this.label; } + public void setLabel(double label) { this.label = label; } } // Set up contexts. @@ -565,6 +565,11 @@ import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} +// Labeled and unlabeled instance types. +// Spark SQL can infer schema from case classes. +case class LabeledDocument(id: Long, text: String, label: Double) +case class Document(id: Long, text: String) + val conf = new SparkConf().setAppName("CrossValidatorExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) @@ -655,6 +660,36 @@ import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; +// Labeled and unlabeled instance types. +// Spark SQL can infer schema from Java Beans. +public class Document implements Serializable { + private long id; + private String text; + + public Document(long id, String text) { + this.id = id; + this.text = text; + } + + public long getId() { return this.id; } + public void setId(long id) { this.id = id; } + + public String getText() { return this.text; } + public void setText(String text) { this.text = text; } +} + +public class LabeledDocument extends Document implements Serializable { + private double label; + + public LabeledDocument(long id, String text, double label) { + super(id, text); + this.label = label; + } + + public double getLabel() { return this.label; } + public void setLabel(double label) { this.label = label; } +} + SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample"); JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext jsql = new SQLContext(jsc); diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 0b6db4fcb7b1f..f5aa15b7d9b79 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -173,6 +173,7 @@ to the algorithm. We then output the parameters of the mixture model. {% highlight scala %} import org.apache.spark.mllib.clustering.GaussianMixture +import org.apache.spark.mllib.clustering.GaussianMixtureModel import org.apache.spark.mllib.linalg.Vectors // Load and parse the data @@ -182,6 +183,10 @@ val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble))) // Cluster the data into two classes using GaussianMixture val gmm = new GaussianMixture().setK(2).run(parsedData) +// Save and load model +gmm.save(sc, "myGMMModel") +val sameModel = GaussianMixtureModel.load(sc, "myGMMModel") + // output parameters of max-likelihood model for (i <- 0 until gmm.k) { println("weight=%f\nmu=%s\nsigma=\n%s\n" format @@ -231,6 +236,9 @@ public class GaussianMixtureExample { // Cluster the data into two classes using GaussianMixture GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd()); + // Save and load GaussianMixtureModel + gmm.save(sc, "myGMMModel") + GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc, "myGMMModel") // Output the parameters of the mixture model for(int j=0; j diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 68b1aeb8ebd01..d9f3eb2b74b18 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -274,6 +274,6 @@ If you need a reference to the proper location to put log files in the YARN so t # Important notes - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. -- The local directories used by Spark executors will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. +- In `yarn-cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `yarn-client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `yarn-client` mode, only the Spark executors do. - The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `yarn-cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 6a333fdb562a7..4441d6a000a02 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -624,7 +624,8 @@ tuples or lists in the RDD created in the step 1. For example: {% highlight python %} # Import SQLContext and data types -from pyspark.sql import * +from pyspark.sql import SQLContext +from pyspark.sql.types import * # sc is an existing SparkContext. sqlContext = SQLContext(sc) @@ -1405,7 +1406,7 @@ DataFrame jdbcDF = sqlContext.load("jdbc", options) {% highlight python %} -df = sqlContext.load("jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") +df = sqlContext.load(source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") {% endhighlight %} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 322de7bf2fed8..51d273af8da84 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -28,6 +28,7 @@ import scala.language.postfixOps import com.google.common.base.Charsets import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.commons.lang3.RandomUtils import org.apache.flume.source.avro import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} import org.jboss.netty.channel.ChannelPipeline @@ -40,7 +41,6 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} -import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerReceiverStarted} import org.apache.spark.util.Utils class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { @@ -76,7 +76,8 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L /** Find a free port */ private def findFreePort(): Int = { - Utils.startServiceOnPort(23456, (trialPort: Int) => { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { val socket = new ServerSocket(trialPort) socket.close() (null, trialPort) diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index 0f3298af6234a..24d78ecb3a97d 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -25,6 +25,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.activemq.broker.{TransportConnector, BrokerService} +import org.apache.commons.lang3.RandomUtils import org.eclipse.paho.client.mqttv3._ import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence @@ -113,7 +114,8 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { } private def findFreePort(): Int = { - Utils.startServiceOnPort(23456, (trialPort: Int) => { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { val socket = new ServerSocket(trialPort) socket.close() (null, trialPort) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index ad4bfe077293a..a9f04b559c3d1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -121,6 +121,22 @@ abstract class VertexRDD[VD]( */ def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] + /** + * For each VertexId present in both `this` and `other`, minus will act as a set difference + * operation returning only those unique VertexId's present in `this`. + * + * @param other an RDD to run the set operation against + */ + def minus(other: RDD[(VertexId, VD)]): VertexRDD[VD] + + /** + * For each VertexId present in both `this` and `other`, minus will act as a set difference + * operation returning only those unique VertexId's present in `this`. + * + * @param other a VertexRDD to run the set operation against + */ + def minus(other: VertexRDD[VD]): VertexRDD[VD] + /** * For each vertex present in both `this` and `other`, `diff` returns only those vertices with * differing values; for values that are different, keeps the values from `other`. This is diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala index 4fd2548b7faf6..b90f9fa327052 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala @@ -88,6 +88,21 @@ private[graphx] abstract class VertexPartitionBaseOps this.withMask(newMask) } + /** Hides the VertexId's that are the same between `this` and `other`. */ + def minus(other: Self[VD]): Self[VD] = { + if (self.index != other.index) { + logWarning("Minus operations on two VertexPartitions with different indexes is slow.") + minus(createUsingIndex(other.iterator)) + } else { + self.withMask(self.mask.andNot(other.mask)) + } + } + + /** Hides the VertexId's that are the same between `this` and `other`. */ + def minus(other: Iterator[(VertexId, VD)]): Self[VD] = { + minus(createUsingIndex(other)) + } + /** * Hides vertices that are the same between this and other. For vertices that are different, keeps * the values from `other`. The indices of `this` and `other` must be the same. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 125692ddaad83..349c8545bf201 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -103,6 +103,31 @@ class VertexRDDImpl[VD] private[graphx] ( override def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] = this.mapVertexPartitions(_.map(f)) + override def minus(other: RDD[(VertexId, VD)]): VertexRDD[VD] = { + minus(this.aggregateUsingIndex(other, (a: VD, b: VD) => a)) + } + + override def minus (other: VertexRDD[VD]): VertexRDD[VD] = { + other match { + case other: VertexRDD[_] if this.partitioner == other.partitioner => + this.withPartitionsRDD[VD]( + partitionsRDD.zipPartitions( + other.partitionsRDD, preservesPartitioning = true) { + (thisIter, otherIter) => + val thisPart = thisIter.next() + val otherPart = otherIter.next() + Iterator(thisPart.minus(otherPart)) + }) + case _ => + this.withPartitionsRDD[VD]( + partitionsRDD.zipPartitions( + other.partitionBy(this.partitioner.get), preservesPartitioning = true) { + (partIter, msgs) => partIter.map(_.minus(msgs)) + } + ) + } + } + override def diff(other: RDD[(VertexId, VD)]): VertexRDD[VD] = { diff(this.aggregateUsingIndex(other, (a: VD, b: VD) => a)) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index 4f7a442ab503d..c9443d11c76cf 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -47,6 +47,35 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { } } + test("minus") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 75, 2).map(i => (i.toLong, 0))).cache() + val vertexB = VertexRDD(sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1))).cache() + val vertexC = vertexA.minus(vertexB) + assert(vertexC.map(_._1).collect.toSet === (0 until 25).toSet) + } + } + + test("minus with RDD[(VertexId, VD)]") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 75, 2).map(i => (i.toLong, 0))).cache() + val vertexB: RDD[(VertexId, Int)] = + sc.parallelize(25 until 100, 2).map(i => (i.toLong, 1)).cache() + val vertexC = vertexA.minus(vertexB) + assert(vertexC.map(_._1).collect.toSet === (0 until 25).toSet) + } + } + + test("minus with non-equal number of partitions") { + withSpark { sc => + val vertexA = VertexRDD(sc.parallelize(0 until 75, 5).map(i => (i.toLong, 0))) + val vertexB = VertexRDD(sc.parallelize(50 until 100, 2).map(i => (i.toLong, 1))) + assert(vertexA.partitions.size != vertexB.partitions.size) + val vertexC = vertexA.minus(vertexB) + assert(vertexC.map(_._1).collect.toSet === (0 until 50).toSet) + } + } + test("diff") { withSpark { sc => val n = 100 @@ -71,7 +100,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { } } - test("diff vertices with the non-equal number of partitions") { + test("diff vertices with non-equal number of partitions") { withSpark { sc => val vertexA = VertexRDD(sc.parallelize(0 until 24, 3).map(i => (i.toLong, 0))) val vertexB = VertexRDD(sc.parallelize(8 until 16, 2).map(i => (i.toLong, 1))) @@ -96,7 +125,7 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { } } - test("leftJoin vertices with the non-equal number of partitions") { + test("leftJoin vertices with non-equal number of partitions") { withSpark { sc => val vertexA = VertexRDD(sc.parallelize(0 until 100, 2).map(i => (i.toLong, 1))) val vertexB = VertexRDD( diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index dc90e9e987234..2da5f7278729e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -147,7 +147,6 @@ void addOptionString(List cmd, String options) { */ List buildClassPath(String appClassPath) throws IOException { String sparkHome = getSparkHome(); - String scala = getScalaVersion(); List cp = new ArrayList(); addToClassPath(cp, getenv("SPARK_CLASSPATH")); @@ -158,6 +157,7 @@ List buildClassPath(String appClassPath) throws IOException { boolean prependClasses = !isEmpty(getenv("SPARK_PREPEND_CLASSES")); boolean isTesting = "1".equals(getenv("SPARK_TESTING")); if (prependClasses || isTesting) { + String scala = getScalaVersion(); List projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx", "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", "yarn", "launcher"); @@ -182,7 +182,7 @@ List buildClassPath(String appClassPath) throws IOException { addToClassPath(cp, String.format("%s/core/target/jars/*", sparkHome)); } - String assembly = findAssembly(scala); + String assembly = findAssembly(); addToClassPath(cp, assembly); // When Hive support is needed, Datanucleus jars must be included on the classpath. Datanucleus @@ -330,7 +330,7 @@ String getenv(String key) { return firstNonEmpty(childEnv.get(key), System.getenv(key)); } - private String findAssembly(String scalaVersion) { + private String findAssembly() { String sparkHome = getSparkHome(); File libdir; if (new File(sparkHome, "RELEASE").isFile()) { @@ -338,7 +338,7 @@ private String findAssembly(String scalaVersion) { checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", libdir.getAbsolutePath()); } else { - libdir = new File(sparkHome, String.format("assembly/target/scala-%s", scalaVersion)); + libdir = new File(sparkHome, String.format("assembly/target/scala-%s", getScalaVersion())); } final Pattern re = Pattern.compile("spark-assembly.*hadoop.*\\.jar"); diff --git a/make-distribution.sh b/make-distribution.sh index 8162fe94c1af0..9ed1abfe8c598 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -32,7 +32,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.6.1" +TACHYON_VERSION="0.5.0" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 0b1f90daa7d8e..68401e36950bd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{ParamMap, IntParam, BooleanParam, Param} import org.apache.spark.sql.types.{DataType, StringType, ArrayType} /** @@ -39,3 +39,67 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { override protected def outputDataType: DataType = new ArrayType(StringType, false) } + +/** + * :: AlphaComponent :: + * A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default) + * or using it to split the text (set matching to false). Optional parameters also allow to fold + * the text to lowercase prior to it being tokenized and to filer tokens using a minimal length. + * It returns an array of strings that can be empty. + * The default parameters are regex = "\\p{L}+|[^\\p{L}\\s]+", matching = true, + * lowercase = false, minTokenLength = 1 + */ +@AlphaComponent +class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] { + + /** + * param for minimum token length, default is one to avoid returning empty strings + * @group param + */ + val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length", Some(1)) + + /** @group setParam */ + def setMinTokenLength(value: Int): this.type = set(minTokenLength, value) + + /** @group getParam */ + def getMinTokenLength: Int = get(minTokenLength) + + /** + * param sets regex as splitting on gaps (true) or matching tokens (false) + * @group param + */ + val gaps: BooleanParam = new BooleanParam( + this, "gaps", "Set regex to match gaps or tokens", Some(false)) + + /** @group setParam */ + def setGaps(value: Boolean): this.type = set(gaps, value) + + /** @group getParam */ + def getGaps: Boolean = get(gaps) + + /** + * param sets regex pattern used by tokenizer + * @group param + */ + val pattern: Param[String] = new Param( + this, "pattern", "regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+")) + + /** @group setParam */ + def setPattern(value: String): this.type = set(pattern, value) + + /** @group getParam */ + def getPattern: String = get(pattern) + + override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str => + val re = paramMap(pattern).r + val tokens = if (paramMap(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq + val minLength = paramMap(minTokenLength) + tokens.filter(_.length >= minLength) + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType == StringType, s"Input type must be string type but got $inputType.") + } + + override protected def outputDataType: DataType = new ArrayType(StringType, false) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 15ca2547d56a8..e39156734794c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -111,9 +111,11 @@ private[python] class PythonMLLibAPI extends Serializable { initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): JList[Object] = { + intercept: Boolean, + validateData: Boolean): JList[Object] = { val lrAlg = new LinearRegressionWithSGD() lrAlg.setIntercept(intercept) + .setValidateData(validateData) lrAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -135,8 +137,12 @@ private[python] class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Vector): JList[Object] = { + initialWeights: Vector, + intercept: Boolean, + validateData: Boolean): JList[Object] = { val lassoAlg = new LassoWithSGD() + lassoAlg.setIntercept(intercept) + .setValidateData(validateData) lassoAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -157,8 +163,12 @@ private[python] class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Vector): JList[Object] = { + initialWeights: Vector, + intercept: Boolean, + validateData: Boolean): JList[Object] = { val ridgeAlg = new RidgeRegressionWithSGD() + ridgeAlg.setIntercept(intercept) + .setValidateData(validateData) ridgeAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index af6f83c74bb40..ec65a3da689de 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -19,11 +19,17 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BreezeVector} +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian -import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, Row} /** * :: Experimental :: @@ -41,10 +47,16 @@ import org.apache.spark.rdd.RDD @Experimental class GaussianMixtureModel( val weights: Array[Double], - val gaussians: Array[MultivariateGaussian]) extends Serializable { + val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{ require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") - + + override protected def formatVersion = "1.0" + + override def save(sc: SparkContext, path: String): Unit = { + GaussianMixtureModel.SaveLoadV1_0.save(sc, path, weights, gaussians) + } + /** Number of gaussians in mixture */ def k: Int = weights.length @@ -83,5 +95,79 @@ class GaussianMixtureModel( p(i) /= pSum } p - } + } +} + +@Experimental +object GaussianMixtureModel extends Loader[GaussianMixtureModel] { + + private object SaveLoadV1_0 { + + case class Data(weight: Double, mu: Vector, sigma: Matrix) + + val formatVersionV1_0 = "1.0" + + val classNameV1_0 = "org.apache.spark.mllib.clustering.GaussianMixtureModel" + + def save( + sc: SparkContext, + path: String, + weights: Array[Double], + gaussians: Array[MultivariateGaussian]): Unit = { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render + (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ ("k" -> weights.length))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val dataArray = Array.tabulate(weights.length) { i => + Data(weights(i), gaussians(i).mu, gaussians(i).sigma) + } + sc.parallelize(dataArray, 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): GaussianMixtureModel = { + val dataPath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataFrame = sqlContext.parquetFile(dataPath) + val dataArray = dataFrame.select("weight", "mu", "sigma").collect() + + // Check schema explicitly since erasure makes it hard to use match-case for checking. + Loader.checkSchema[Data](dataFrame.schema) + + val (weights, gaussians) = dataArray.map { + case Row(weight: Double, mu: Vector, sigma: Matrix) => + (weight, new MultivariateGaussian(mu, sigma)) + }.unzip + + return new GaussianMixtureModel(weights.toArray, gaussians.toArray) + } + } + + override def load(sc: SparkContext, path: String) : GaussianMixtureModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val k = (metadata \ "k").extract[Int] + val classNameV1_0 = SaveLoadV1_0.classNameV1_0 + (loadedClassName, version) match { + case (classNameV1_0, "1.0") => { + val model = SaveLoadV1_0.load(sc, path) + require(model.weights.length == k, + s"GaussianMixtureModel requires weights of length $k " + + s"got weights of length ${model.weights.length}") + require(model.gaussians.length == k, + s"GaussianMixtureModel requires gaussians of length $k" + + s"got gaussians of length ${model.gaussians.length}") + model + } + case _ => throw new Exception( + s"GaussianMixtureModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 5e17c8da61134..9d63a08e211bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.clustering import java.util.Random -import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy} +import breeze.linalg.{DenseVector => BDV, normalize} import org.apache.spark.Logging import org.apache.spark.annotation.Experimental diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 849f44295f089..d1a174063caba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -187,6 +187,8 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { override def hashCode(): Int = 1994 + override def typeName: String = "matrix" + private[spark] override def asNullable: MatrixUDT = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 2cda9b252ee06..328dbe2ce11fa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -185,6 +185,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { override def hashCode: Int = 7919 + override def typeName: String = "vector" + private[spark] override def asNullable: VectorUDT = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 45b9ebb4cc0d6..9fd60ff7a0c79 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -211,6 +211,10 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { + if (numFeatures < 0) { + numFeatures = input.map(_.features.size).first() + } + if (input.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java new file mode 100644 index 0000000000000..3806f650025b2 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + +public class JavaTokenizerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaTokenizerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void regexTokenizer() { + RegexTokenizer myRegExTokenizer = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") + .setPattern("\\s") + .setGaps(true) + .setMinTokenLength(3); + + JavaRDD rdd = jsc.parallelize(Lists.newArrayList( + new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}), + new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"}) + )); + DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class); + + Row[] pairs = myRegExTokenizer.transform(dataset) + .select("tokens", "wantedTokens") + .collect(); + + for (Row r : pairs) { + Assert.assertEquals(r.get(0), r.get(1)); + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala new file mode 100644 index 0000000000000..bf862b912d326 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.BeanInfo + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +@BeanInfo +case class TokenizerTestData(rawText: String, wantedTokens: Seq[String]) { + /** Constructor used in [[org.apache.spark.ml.feature.JavaTokenizerSuite]] */ + def this(rawText: String, wantedTokens: Array[String]) = this(rawText, wantedTokens.toSeq) +} + +class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { + import org.apache.spark.ml.feature.RegexTokenizerSuite._ + + @transient var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("RegexTokenizer") { + val tokenizer = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") + + val dataset0 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")), + TokenizerTestData("Te,st. punct", Seq("Te", ",", "st", ".", "punct")) + )) + testRegexTokenizer(tokenizer, dataset0) + + val dataset1 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization")), + TokenizerTestData("Te,st. punct", Seq("punct")) + )) + + tokenizer.setMinTokenLength(3) + testRegexTokenizer(tokenizer, dataset1) + + tokenizer + .setPattern("\\s") + .setGaps(true) + .setMinTokenLength(0) + val dataset2 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization.")), + TokenizerTestData("Te,st. punct", Seq("Te,st.", "", "punct")) + )) + testRegexTokenizer(tokenizer, dataset2) + } +} + +object RegexTokenizerSuite extends FunSuite { + + def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = { + t.transform(dataset) + .select("tokens", "wantedTokens") + .collect() + .foreach { + case Row(tokens, wantedTokens) => + assert(tokens === wantedTokens) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index aaa81da9e273c..a26c52852c4d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -425,6 +425,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M val model = lr.run(testRDD) + val numFeatures = testRDD.map(_.features.size).first() + val initialWeights = Vectors.dense(new Array[Double]((numFeatures + 1) * 2)) + val model2 = lr.run(testRDD, initialWeights) + + LogisticRegressionSuite.checkModelsEqual(model, model2) + /** * The following is the instruction to reproduce the model using R's glmnet package. * diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index 1b46a4012d731..f356ffa3e3a26 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Matrices} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { test("single cluster") { @@ -48,13 +49,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { } test("two clusters") { - val data = sc.parallelize(Array( - Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), - Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), - Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), - Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), - Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) - )) + val data = sc.parallelize(GaussianTestData.data) // we set an initial gaussian to induce expected results val initialGmm = new GaussianMixtureModel( @@ -105,14 +100,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { } test("two clusters with sparse data") { - val data = sc.parallelize(Array( - Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), - Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), - Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), - Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), - Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) - )) - + val data = sc.parallelize(GaussianTestData.data) val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray)) // we set an initial gaussian to induce expected results val initialGmm = new GaussianMixtureModel( @@ -138,4 +126,38 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3) assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3) } + + test("model save / load") { + val data = sc.parallelize(GaussianTestData.data) + + val gmm = new GaussianMixture().setK(2).setSeed(0).run(data) + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + try { + gmm.save(sc, path) + + // TODO: GaussianMixtureModel should implement equals/hashcode directly. + val sameModel = GaussianMixtureModel.load(sc, path) + assert(sameModel.k === gmm.k) + (0 until sameModel.k).foreach { i => + assert(sameModel.gaussians(i).mu === gmm.gaussians(i).mu) + assert(sameModel.gaussians(i).sigma === gmm.gaussians(i).sigma) + } + } finally { + Utils.deleteRecursively(tempDir) + } + } + + object GaussianTestData { + + val data = Array( + Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), + Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), + Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), + Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), + Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) + ) + + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 96f677db3f377..0d2cec58e2c03 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -436,5 +436,7 @@ class MatricesSuite extends FunSuite { Seq(dm1, dm2, dm3, sm1, sm2, sm3).foreach { mat => assert(mat.toArray === mUDT.deserialize(mUDT.serialize(mat)).toArray) } + assert(mUDT.typeName == "matrix") + assert(mUDT.simpleString == "matrix") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 5def899cea117..2839c4c289b2d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -187,6 +187,8 @@ class VectorsSuite extends FunSuite { for (v <- Seq(dv0, dv1, sv0, sv1)) { assert(v === udt.deserialize(udt.serialize(v))) } + assert(udt.typeName == "vector") + assert(udt.simpleString == "vector") } test("fromBreeze") { diff --git a/pom.xml b/pom.xml index 23bb16130b504..b3cecd1893a06 100644 --- a/pom.xml +++ b/pom.xml @@ -1452,7 +1452,8 @@ ${basedir}/src/test/scala scalastyle-config.xml scalastyle-output.xml - UTF-8 + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 328d59485a731..b9f40046e15a2 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -44,7 +44,16 @@ object MimaExcludes { // the maven-generated artifacts in 1.3. excludePackage("org.spark-project.jetty"), MimaBuild.excludeSparkPackage("unused"), - ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional") + ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.rdd.JdbcRDD.compute"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast") + ) ++ Seq( + // SPARK-6510 Add a Graph#minus method acting as Set#difference + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") ) case v if v.startsWith("1.3") => diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 414a0ada80787..209f1ee473b5b 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -140,6 +140,13 @@ class LinearRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=100, step=1.0, + ... miniBatchFraction=1.0, initialWeights=array([1.0]), regParam=0.1, regType="l2", + ... intercept=True, validateData=True) + >>> abs(lrm.predict(array([0.0])) - 0) < 0.5 + True + >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True """ def save(self, sc, path): java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel( @@ -173,7 +180,8 @@ class LinearRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=0.0, regType=None, intercept=False): + initialWeights=None, regParam=0.0, regType=None, intercept=False, + validateData=True): """ Train a linear regression model on the given data. @@ -195,15 +203,18 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, (default: None) - @param intercept: Boolean parameter which indicates the use + :param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features are activated or not). (default: False) + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before training. + (default: True) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), - regType, bool(intercept)) + regType, bool(intercept), bool(validateData)) return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights) @@ -253,6 +264,13 @@ class LassoModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> lrm = LassoWithSGD.train(sc.parallelize(data), iterations=100, step=1.0, + ... regParam=0.01, miniBatchFraction=1.0, initialWeights=array([1.0]), intercept=True, + ... validateData=True) + >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True """ def save(self, sc, path): java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel( @@ -273,11 +291,13 @@ class LassoWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, - miniBatchFraction=1.0, initialWeights=None): + miniBatchFraction=1.0, initialWeights=None, intercept=False, + validateData=True): """Train a Lasso regression model on the given data.""" def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), - float(regParam), float(miniBatchFraction), i) + float(regParam), float(miniBatchFraction), i, bool(intercept), + bool(validateData)) return _regression_train_wrapper(train, LassoModel, data, initialWeights) @@ -327,6 +347,13 @@ class RidgeRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=100, step=1.0, + ... regParam=0.01, miniBatchFraction=1.0, initialWeights=array([1.0]), intercept=True, + ... validateData=True) + >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True """ def save(self, sc, path): java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel( @@ -347,11 +374,13 @@ class RidgeRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, - miniBatchFraction=1.0, initialWeights=None): + miniBatchFraction=1.0, initialWeights=None, intercept=False, + validateData=True): """Train a ridge regression model on the given data.""" def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), - float(regParam), float(miniBatchFraction), i) + float(regParam), float(miniBatchFraction), i, bool(intercept), + bool(validateData)) return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5cb89da7a8ed5..d51309f7ef5aa 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -520,6 +520,25 @@ def sort(self, *cols): orderBy = sort + def describe(self, *cols): + """Computes statistics for numeric columns. + + This include count, mean, stddev, min, and max. If no columns are + given, this function computes statistics for all numerical columns. + + >>> df.describe().show() + summary age + count 2 + mean 3.5 + stddev 1.5 + min 2 + max 5 + """ + cols = ListConverter().convert(cols, + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)) + return DataFrame(jdf, self.sql_ctx) + def head(self, n=None): """ Return the first `n` rows or the first row if n is None. @@ -985,6 +1004,23 @@ def substr(self, startPos, length): __getslice__ = substr + def inSet(self, *cols): + """ A boolean expression that is evaluated to true if the value of this + expression is contained by the evaluated values of the arguments. + + >>> df[df.name.inSet("Bob", "Mike")].collect() + [Row(age=5, name=u'Bob')] + >>> df[df.age.inSet([1, 2, 3])].collect() + [Row(age=2, name=u'Alice')] + """ + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] + sc = SparkContext._active_spark_context + jcols = ListConverter().convert(cols, sc._gateway._gateway_client) + jc = getattr(self._jc, "in")(sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + # order asc = _unary_op("asc", "Returns a sort expression based on the" " ascending order of the given column name.") diff --git a/repl/pom.xml b/repl/pom.xml index edfa1c7f2c29c..03053b4c3b287 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -84,6 +84,11 @@ scalacheck_${scala.binary.version} test
+ + org.mockito + mockito-all + test + diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 9805609120005..004941d5f50ae 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -17,9 +17,10 @@ package org.apache.spark.repl -import java.io.{ByteArrayOutputStream, InputStream, FileNotFoundException} -import java.net.{URI, URL, URLEncoder} -import java.util.concurrent.{Executors, ExecutorService} +import java.io.{IOException, ByteArrayOutputStream, InputStream} +import java.net.{HttpURLConnection, URI, URL, URLEncoder} + +import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} @@ -43,6 +44,9 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader val parentLoader = new ParentClassLoader(parent) + // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes + private[repl] var httpUrlConnectionTimeoutMillis: Int = -1 + // Hadoop FileSystem object for our URI, if it isn't using HTTP var fileSystem: FileSystem = { if (Set("http", "https", "ftp").contains(uri.getScheme)) { @@ -71,30 +75,66 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } + private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { + val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { + val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) + val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) + newuri.toURL + } else { + new URL(classUri + "/" + urlEncode(pathInDirectory)) + } + val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(), + SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection] + // Set the connection timeouts (for testing purposes) + if (httpUrlConnectionTimeoutMillis != -1) { + connection.setConnectTimeout(httpUrlConnectionTimeoutMillis) + connection.setReadTimeout(httpUrlConnectionTimeoutMillis) + } + connection.connect() + try { + if (connection.getResponseCode != 200) { + // Close the error stream so that the connection is eligible for re-use + try { + connection.getErrorStream.close() + } catch { + case ioe: IOException => + logError("Exception while closing error stream", ioe) + } + throw new ClassNotFoundException(s"Class file not found at URL $url") + } else { + connection.getInputStream + } + } catch { + case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] => + connection.disconnect() + throw e + } + } + + private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = { + val path = new Path(directory, pathInDirectory) + if (fileSystem.exists(path)) { + fileSystem.open(path) + } else { + throw new ClassNotFoundException(s"Class file not found at path $path") + } + } + def findClassLocally(name: String): Option[Class[_]] = { + val pathInDirectory = name.replace('.', '/') + ".class" + var inputStream: InputStream = null try { - val pathInDirectory = name.replace('.', '/') + ".class" - val inputStream = { + inputStream = { if (fileSystem != null) { - fileSystem.open(new Path(directory, pathInDirectory)) + getClassFileInputStreamFromFileSystem(pathInDirectory) } else { - val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { - val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) - val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) - newuri.toURL - } else { - new URL(classUri + "/" + urlEncode(pathInDirectory)) - } - - Utils.setupSecureURLConnection(url.openConnection(), SparkEnv.get.securityManager) - .getInputStream + getClassFileInputStreamFromHttpServer(pathInDirectory) } } val bytes = readAndTransformClass(name, inputStream) - inputStream.close() Some(defineClass(name, bytes, 0, bytes.length)) } catch { - case e: FileNotFoundException => + case e: ClassNotFoundException => // We did not find the class logDebug(s"Did not load class $name from REPL class server at $uri", e) None @@ -102,6 +142,15 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader // Something bad happened while checking if the class exists logError(s"Failed to check existence of class $name on REPL class server at $uri", e) None + } finally { + if (inputStream != null) { + try { + inputStream.close() + } catch { + case e: Exception => + logError("Exception while closing inputStream", e) + } + } } } diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index 6a79e76a34db8..c709cde740748 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -20,13 +20,25 @@ package org.apache.spark.repl import java.io.File import java.net.{URL, URLClassLoader} +import scala.concurrent.duration._ +import scala.language.implicitConversions +import scala.language.postfixOps + import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite +import org.scalatest.concurrent.Interruptor +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.mock.MockitoSugar +import org.mockito.Mockito._ -import org.apache.spark.{SparkConf, TestUtils} +import org.apache.spark._ import org.apache.spark.util.Utils -class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { +class ExecutorClassLoaderSuite + extends FunSuite + with BeforeAndAfterAll + with MockitoSugar + with Logging { val childClassNames = List("ReplFakeClass1", "ReplFakeClass2") val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3") @@ -34,6 +46,7 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { var tempDir2: File = _ var url1: String = _ var urls2: Array[URL] = _ + var classServer: HttpServer = _ override def beforeAll() { super.beforeAll() @@ -47,8 +60,12 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { override def afterAll() { super.afterAll() + if (classServer != null) { + classServer.stop() + } Utils.deleteRecursively(tempDir1) Utils.deleteRecursively(tempDir2) + SparkEnv.set(null) } test("child first") { @@ -83,4 +100,53 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { } } + test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") { + // This is a regression test for SPARK-6209, a bug where each failed attempt to load a class + // from the driver's class server would leak a HTTP connection, causing the class server's + // thread / connection pool to be exhausted. + val conf = new SparkConf() + val securityManager = new SecurityManager(conf) + classServer = new HttpServer(conf, tempDir1, securityManager) + classServer.start() + // ExecutorClassLoader uses SparkEnv's SecurityManager, so we need to mock this + val mockEnv = mock[SparkEnv] + when(mockEnv.securityManager).thenReturn(securityManager) + SparkEnv.set(mockEnv) + // Create an ExecutorClassLoader that's configured to load classes from the HTTP server + val parentLoader = new URLClassLoader(Array.empty, null) + val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false) + classLoader.httpUrlConnectionTimeoutMillis = 500 + // Check that this class loader can actually load classes that exist + val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClassVersion = fakeClass.toString + assert(fakeClassVersion === "1") + // Try to perform a full GC now, since GC during the test might mask resource leaks + System.gc() + // When the original bug occurs, the test thread becomes blocked in a classloading call + // and does not respond to interrupts. Therefore, use a custom ScalaTest interruptor to + // shut down the HTTP server when the test times out + val interruptor: Interruptor = new Interruptor { + override def apply(thread: Thread): Unit = { + classServer.stop() + classServer = null + thread.interrupt() + } + } + def tryAndFailToLoadABunchOfClasses(): Unit = { + // The number of trials here should be much larger than Jetty's thread / connection limit + // in order to expose thread or connection leaks + for (i <- 1 to 1000) { + if (Thread.currentThread().isInterrupted) { + throw new InterruptedException() + } + // Incorporate the iteration number into the class name in order to avoid any response + // caching that might be added in the future + intercept[ClassNotFoundException] { + classLoader.loadClass(s"ReplFakeClassDoesNotExist$i").newInstance() + } + } + } + failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor) + } + } diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 0ff521706c71a..459a5035d4984 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -137,9 +137,9 @@ - + - + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 15add84878ecf..34fedead44db3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -30,6 +30,12 @@ class AnalysisException protected[sql] ( val startPosition: Option[Int] = None) extends Exception with Serializable { + def withPosition(line: Option[Int], startPosition: Option[Int]) = { + val newException = new AnalysisException(message, line, startPosition) + newException.setStackTrace(getStackTrace) + newException + } + override def getMessage: String = { val lineAnnotation = line.map(l => s" line $l").getOrElse("") val positionAnnotation = startPosition.map(p => s" pos $p").getOrElse("") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 366be00473d1c..3823584287741 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -26,7 +26,7 @@ import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.plans.logical._ private[sql] object KeywordNormalizer { - def apply(str: String) = str.toLowerCase() + def apply(str: String): String = str.toLowerCase() } private[sql] abstract class AbstractSparkSQLParser @@ -42,7 +42,7 @@ private[sql] abstract class AbstractSparkSQLParser } protected case class Keyword(str: String) { - def normalize = KeywordNormalizer(str) + def normalize: String = KeywordNormalizer(str) def parser: Parser[String] = normalize } @@ -81,7 +81,7 @@ private[sql] abstract class AbstractSparkSQLParser class SqlLexical extends StdLexical { case class FloatLit(chars: String) extends Token { - override def toString = chars + override def toString: String = chars } /* This is a work around to support the lazy setting */ @@ -120,7 +120,7 @@ class SqlLexical extends StdLexical { | failure("illegal character") ) - override def identChar = letter | elem('_') + override def identChar: Parser[Elem] = letter | elem('_') override def whitespace: Parser[Any] = ( whitespaceChar diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 92d3db077c5e1..44eceb0b372e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -64,9 +64,7 @@ class Analyzer(catalog: Catalog, UnresolvedHavingClauseAttributes :: TrimGroupingAliases :: typeCoercionRules ++ - extendedResolutionRules : _*), - Batch("Remove SubQueries", fixedPoint, - EliminateSubQueries) + extendedResolutionRules : _*) ) /** @@ -170,12 +168,12 @@ class Analyzer(catalog: Catalog, * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ object ResolveRelations extends Rule[LogicalPlan] { - def getTable(u: UnresolvedRelation) = { + def getTable(u: UnresolvedRelation): LogicalPlan = { try { catalog.lookupRelation(u.tableIdentifier, u.alias) } catch { case _: NoSuchTableException => - u.failAnalysis(s"no such table ${u.tableIdentifier}") + u.failAnalysis(s"no such table ${u.tableName}") } } @@ -275,7 +273,8 @@ class Analyzer(catalog: Catalog, q.asInstanceOf[GroupingAnalytics].gid case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = q.resolveChildren(name, resolver).getOrElse(u) + val result = + withPosition(u) { q.resolveChildren(name, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedGetField(child, fieldName) if child.resolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 9e6e2912e0622..5eb7dff0cede8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -86,12 +86,12 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { tables += ((getDbTableName(tableIdent), plan)) } - override def unregisterTable(tableIdentifier: Seq[String]) = { + override def unregisterTable(tableIdentifier: Seq[String]): Unit = { val tableIdent = processTableIdentifier(tableIdentifier) tables -= getDbTableName(tableIdent) } - override def unregisterAllTables() = { + override def unregisterAllTables(): Unit = { tables.clear() } @@ -147,8 +147,8 @@ trait OverrideCatalog extends Catalog { } abstract override def lookupRelation( - tableIdentifier: Seq[String], - alias: Option[String] = None): LogicalPlan = { + tableIdentifier: Seq[String], + alias: Option[String] = None): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) val overriddenTable = overrides.get(getDBTable(tableIdent)) val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r)) @@ -205,15 +205,15 @@ trait OverrideCatalog extends Catalog { */ object EmptyCatalog extends Catalog { - val caseSensitive: Boolean = true + override val caseSensitive: Boolean = true - def tableExists(tableIdentifier: Seq[String]): Boolean = { + override def tableExists(tableIdentifier: Seq[String]): Boolean = { throw new UnsupportedOperationException } - def lookupRelation( - tableIdentifier: Seq[String], - alias: Option[String] = None) = { + override def lookupRelation( + tableIdentifier: Seq[String], + alias: Option[String] = None): LogicalPlan = { throw new UnsupportedOperationException } @@ -221,11 +221,11 @@ object EmptyCatalog extends Catalog { throw new UnsupportedOperationException } - def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { throw new UnsupportedOperationException } - def unregisterTable(tableIdentifier: Seq[String]): Unit = { + override def unregisterTable(tableIdentifier: Seq[String]): Unit = { throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4e8fc892f3eea..40472a1cbb3b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -33,7 +33,7 @@ class CheckAnalysis { */ val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil - def failAnalysis(msg: String) = { + def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg) } @@ -63,7 +63,7 @@ class CheckAnalysis { s"filter expression '${f.condition.prettyString}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") - case aggregatePlan@Aggregate(groupingExprs, aggregateExprs, child) => + case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case _: AggregateExpression => // OK case e: Attribute if !groupingExprs.contains(e) => @@ -85,14 +85,18 @@ class CheckAnalysis { cleaned.foreach(checkValidAggregateExpression) - case o if o.children.nonEmpty && - !o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) => - val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",") - val input = o.inputSet.map(_.prettyString).mkString(",") + case _ => // Fallbacks to the following checks + } + + operator match { + case o if o.children.nonEmpty && o.missingInput.nonEmpty => + val missingAttributes = o.missingInput.mkString(",") + val input = o.inputSet.mkString(",") - failAnalysis(s"resolved attributes $missingAttributes missing from $input") + failAnalysis( + s"resolved attribute(s) $missingAttributes missing from $input " + + s"in operator ${operator.simpleString}") - // Catch all case o if !o.resolved => failAnalysis( s"unresolved operator ${operator.simpleString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9f334f6d42ad1..c43ea55899695 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -35,7 +35,7 @@ trait OverrideFunctionRegistry extends FunctionRegistry { val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive) - def registerFunction(name: String, builder: FunctionBuilder) = { + override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } @@ -47,7 +47,7 @@ trait OverrideFunctionRegistry extends FunctionRegistry { class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistry { val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive) - def registerFunction(name: String, builder: FunctionBuilder) = { + override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } @@ -61,13 +61,15 @@ class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistr * functions are already filled in and the analyser needs only to resolve attribute references. */ object EmptyFunctionRegistry extends FunctionRegistry { - def registerFunction(name: String, builder: FunctionBuilder) = ??? + override def registerFunction(name: String, builder: FunctionBuilder): Unit = { + throw new UnsupportedOperationException + } - def lookupFunction(name: String, children: Seq[Expression]): Expression = { + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } - def caseSensitive: Boolean = ??? + override def caseSensitive: Boolean = throw new UnsupportedOperationException } /** @@ -76,7 +78,7 @@ object EmptyFunctionRegistry extends FunctionRegistry { * TODO move this into util folder? */ object StringKeyHashMap { - def apply[T](caseSensitive: Boolean) = caseSensitive match { + def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match { case false => new StringKeyHashMap[T](_.toLowerCase) case true => new StringKeyHashMap[T](identity) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index e95f19e69ed43..c61c395cb4bb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -38,8 +38,16 @@ package object analysis { implicit class AnalysisErrorAt(t: TreeNode[_]) { /** Fails the analysis at the point where a specific tree node was parsed. */ - def failAnalysis(msg: String) = { + def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg, t.origin.line, t.origin.startPosition) } } + + /** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */ + def withPosition[A](t: TreeNode[_])(f: => A) = { + try f catch { + case a: AnalysisException => + throw a.withPosition(t.origin.line, t.origin.startPosition) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index a7cd4124e56f3..300e9ba187bc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.types.DataType /** * Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully @@ -36,7 +37,12 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str case class UnresolvedRelation( tableIdentifier: Seq[String], alias: Option[String] = None) extends LeafNode { - override def output = Nil + + /** Returns a `.` separated name for this relation. */ + def tableName: String = tableIdentifier.mkString(".") + + override def output: Seq[Attribute] = Nil + override lazy val resolved = false } @@ -44,16 +50,16 @@ case class UnresolvedRelation( * Holds the name of an attribute that has yet to be resolved. */ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { - override def exprId = throw new UnresolvedException(this, "exprId") - override def dataType = throw new UnresolvedException(this, "dataType") - override def nullable = throw new UnresolvedException(this, "nullable") - override def qualifiers = throw new UnresolvedException(this, "qualifiers") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - override def newInstance() = this - override def withNullability(newNullability: Boolean) = this - override def withQualifiers(newQualifiers: Seq[String]) = this - override def withName(newName: String) = UnresolvedAttribute(name) + override def newInstance(): UnresolvedAttribute = this + override def withNullability(newNullability: Boolean): UnresolvedAttribute = this + override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this + override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute(name) // Unresolved attributes are transient at compile time and don't get evaluated during execution. override def eval(input: Row = null): EvaluatedType = @@ -63,16 +69,16 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo } case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { - override def dataType = throw new UnresolvedException(this, "dataType") - override def foldable = throw new UnresolvedException(this, "foldable") - override def nullable = throw new UnresolvedException(this, "nullable") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString = s"'$name(${children.mkString(",")})" + override def toString: String = s"'$name(${children.mkString(",")})" } /** @@ -82,17 +88,17 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E trait Star extends Attribute with trees.LeafNode[Expression] { self: Product => - override def name = throw new UnresolvedException(this, "name") - override def exprId = throw new UnresolvedException(this, "exprId") - override def dataType = throw new UnresolvedException(this, "dataType") - override def nullable = throw new UnresolvedException(this, "nullable") - override def qualifiers = throw new UnresolvedException(this, "qualifiers") + override def name: String = throw new UnresolvedException(this, "name") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - override def newInstance() = this - override def withNullability(newNullability: Boolean) = this - override def withQualifiers(newQualifiers: Seq[String]) = this - override def withName(newName: String) = this + override def newInstance(): Star = this + override def withNullability(newNullability: Boolean): Star = this + override def withQualifiers(newQualifiers: Seq[String]): Star = this + override def withName(newName: String): Star = this // Star gets expanded at runtime so we never evaluate a Star. override def eval(input: Row = null): EvaluatedType = @@ -125,7 +131,7 @@ case class UnresolvedStar(table: Option[String]) extends Star { } } - override def toString = table.map(_ + ".").getOrElse("") + "*" + override def toString: String = table.map(_ + ".").getOrElse("") + "*" } /** @@ -140,25 +146,25 @@ case class UnresolvedStar(table: Option[String]) extends Star { case class MultiAlias(child: Expression, names: Seq[String]) extends Attribute with trees.UnaryNode[Expression] { - override def name = throw new UnresolvedException(this, "name") + override def name: String = throw new UnresolvedException(this, "name") - override def exprId = throw new UnresolvedException(this, "exprId") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") - override def dataType = throw new UnresolvedException(this, "dataType") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") - override def nullable = throw new UnresolvedException(this, "nullable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override def qualifiers = throw new UnresolvedException(this, "qualifiers") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - override def newInstance() = this + override def newInstance(): MultiAlias = this - override def withNullability(newNullability: Boolean) = this + override def withNullability(newNullability: Boolean): MultiAlias = this - override def withQualifiers(newQualifiers: Seq[String]) = this + override def withQualifiers(newQualifiers: Seq[String]): MultiAlias = this - override def withName(newName: String) = this + override def withName(newName: String): MultiAlias = this override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") @@ -175,17 +181,17 @@ case class MultiAlias(child: Expression, names: Seq[String]) */ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star { override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions - override def toString = expressions.mkString("ResolvedStar(", ", ", ")") + override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")") } case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression { - override def dataType = throw new UnresolvedException(this, "dataType") - override def foldable = throw new UnresolvedException(this, "foldable") - override def nullable = throw new UnresolvedException(this, "nullable") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString = s"$child.$fieldName" + override def toString: String = s"$child.$fieldName" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 51a09ac0e1249..145f062dd6817 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -61,60 +61,60 @@ package object dsl { trait ImplicitOperators { def expr: Expression - def unary_- = UnaryMinus(expr) - def unary_! = Not(expr) - def unary_~ = BitwiseNot(expr) - - def + (other: Expression) = Add(expr, other) - def - (other: Expression) = Subtract(expr, other) - def * (other: Expression) = Multiply(expr, other) - def / (other: Expression) = Divide(expr, other) - def % (other: Expression) = Remainder(expr, other) - def & (other: Expression) = BitwiseAnd(expr, other) - def | (other: Expression) = BitwiseOr(expr, other) - def ^ (other: Expression) = BitwiseXor(expr, other) - - def && (other: Expression) = And(expr, other) - def || (other: Expression) = Or(expr, other) - - def < (other: Expression) = LessThan(expr, other) - def <= (other: Expression) = LessThanOrEqual(expr, other) - def > (other: Expression) = GreaterThan(expr, other) - def >= (other: Expression) = GreaterThanOrEqual(expr, other) - def === (other: Expression) = EqualTo(expr, other) - def <=> (other: Expression) = EqualNullSafe(expr, other) - def !== (other: Expression) = Not(EqualTo(expr, other)) - - def in(list: Expression*) = In(expr, list) - - def like(other: Expression) = Like(expr, other) - def rlike(other: Expression) = RLike(expr, other) - def contains(other: Expression) = Contains(expr, other) - def startsWith(other: Expression) = StartsWith(expr, other) - def endsWith(other: Expression) = EndsWith(expr, other) - def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)) = + def unary_- : Expression= UnaryMinus(expr) + def unary_! : Predicate = Not(expr) + def unary_~ : Expression = BitwiseNot(expr) + + def + (other: Expression): Expression = Add(expr, other) + def - (other: Expression): Expression = Subtract(expr, other) + def * (other: Expression): Expression = Multiply(expr, other) + def / (other: Expression): Expression = Divide(expr, other) + def % (other: Expression): Expression = Remainder(expr, other) + def & (other: Expression): Expression = BitwiseAnd(expr, other) + def | (other: Expression): Expression = BitwiseOr(expr, other) + def ^ (other: Expression): Expression = BitwiseXor(expr, other) + + def && (other: Expression): Predicate = And(expr, other) + def || (other: Expression): Predicate = Or(expr, other) + + def < (other: Expression): Predicate = LessThan(expr, other) + def <= (other: Expression): Predicate = LessThanOrEqual(expr, other) + def > (other: Expression): Predicate = GreaterThan(expr, other) + def >= (other: Expression): Predicate = GreaterThanOrEqual(expr, other) + def === (other: Expression): Predicate = EqualTo(expr, other) + def <=> (other: Expression): Predicate = EqualNullSafe(expr, other) + def !== (other: Expression): Predicate = Not(EqualTo(expr, other)) + + def in(list: Expression*): Expression = In(expr, list) + + def like(other: Expression): Expression = Like(expr, other) + def rlike(other: Expression): Expression = RLike(expr, other) + def contains(other: Expression): Expression = Contains(expr, other) + def startsWith(other: Expression): Expression = StartsWith(expr, other) + def endsWith(other: Expression): Expression = EndsWith(expr, other) + def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)): Expression = Substring(expr, pos, len) - def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)) = + def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)): Expression = Substring(expr, pos, len) - def isNull = IsNull(expr) - def isNotNull = IsNotNull(expr) + def isNull: Predicate = IsNull(expr) + def isNotNull: Predicate = IsNotNull(expr) - def getItem(ordinal: Expression) = GetItem(expr, ordinal) - def getField(fieldName: String) = UnresolvedGetField(expr, fieldName) + def getItem(ordinal: Expression): Expression = GetItem(expr, ordinal) + def getField(fieldName: String): UnresolvedGetField = UnresolvedGetField(expr, fieldName) - def cast(to: DataType) = Cast(expr, to) + def cast(to: DataType): Expression = Cast(expr, to) - def asc = SortOrder(expr, Ascending) - def desc = SortOrder(expr, Descending) + def asc: SortOrder = SortOrder(expr, Ascending) + def desc: SortOrder = SortOrder(expr, Descending) - def as(alias: String) = Alias(expr, alias)() - def as(alias: Symbol) = Alias(expr, alias.name)() + def as(alias: String): NamedExpression = Alias(expr, alias)() + def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)() } trait ExpressionConversions { implicit class DslExpression(e: Expression) extends ImplicitOperators { - def expr = e + def expr: Expression = e } implicit def booleanToLiteral(b: Boolean): Literal = Literal(b) @@ -144,94 +144,100 @@ package object dsl { } } - def sum(e: Expression) = Sum(e) - def sumDistinct(e: Expression) = SumDistinct(e) - def count(e: Expression) = Count(e) - def countDistinct(e: Expression*) = CountDistinct(e) - def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd) - def avg(e: Expression) = Average(e) - def first(e: Expression) = First(e) - def last(e: Expression) = Last(e) - def min(e: Expression) = Min(e) - def max(e: Expression) = Max(e) - def upper(e: Expression) = Upper(e) - def lower(e: Expression) = Lower(e) - def sqrt(e: Expression) = Sqrt(e) - def abs(e: Expression) = Abs(e) - - implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } + def sum(e: Expression): Expression = Sum(e) + def sumDistinct(e: Expression): Expression = SumDistinct(e) + def count(e: Expression): Expression = Count(e) + def countDistinct(e: Expression*): Expression = CountDistinct(e) + def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression = + ApproxCountDistinct(e, rsd) + def avg(e: Expression): Expression = Average(e) + def first(e: Expression): Expression = First(e) + def last(e: Expression): Expression = Last(e) + def min(e: Expression): Expression = Min(e) + def max(e: Expression): Expression = Max(e) + def upper(e: Expression): Expression = Upper(e) + def lower(e: Expression): Expression = Lower(e) + def sqrt(e: Expression): Expression = Sqrt(e) + def abs(e: Expression): Expression = Abs(e) + + implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? implicit class DslString(val s: String) extends ImplicitOperators { override def expr: Expression = Literal(s) - def attr = analysis.UnresolvedAttribute(s) + def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s) } abstract class ImplicitAttribute extends ImplicitOperators { def s: String - def expr = attr - def attr = analysis.UnresolvedAttribute(s) + def expr: UnresolvedAttribute = attr + def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s) /** Creates a new AttributeReference of type boolean */ - def boolean = AttributeReference(s, BooleanType, nullable = true)() + def boolean: AttributeReference = AttributeReference(s, BooleanType, nullable = true)() /** Creates a new AttributeReference of type byte */ - def byte = AttributeReference(s, ByteType, nullable = true)() + def byte: AttributeReference = AttributeReference(s, ByteType, nullable = true)() /** Creates a new AttributeReference of type short */ - def short = AttributeReference(s, ShortType, nullable = true)() + def short: AttributeReference = AttributeReference(s, ShortType, nullable = true)() /** Creates a new AttributeReference of type int */ - def int = AttributeReference(s, IntegerType, nullable = true)() + def int: AttributeReference = AttributeReference(s, IntegerType, nullable = true)() /** Creates a new AttributeReference of type long */ - def long = AttributeReference(s, LongType, nullable = true)() + def long: AttributeReference = AttributeReference(s, LongType, nullable = true)() /** Creates a new AttributeReference of type float */ - def float = AttributeReference(s, FloatType, nullable = true)() + def float: AttributeReference = AttributeReference(s, FloatType, nullable = true)() /** Creates a new AttributeReference of type double */ - def double = AttributeReference(s, DoubleType, nullable = true)() + def double: AttributeReference = AttributeReference(s, DoubleType, nullable = true)() /** Creates a new AttributeReference of type string */ - def string = AttributeReference(s, StringType, nullable = true)() + def string: AttributeReference = AttributeReference(s, StringType, nullable = true)() /** Creates a new AttributeReference of type date */ - def date = AttributeReference(s, DateType, nullable = true)() + def date: AttributeReference = AttributeReference(s, DateType, nullable = true)() /** Creates a new AttributeReference of type decimal */ - def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)() + def decimal: AttributeReference = + AttributeReference(s, DecimalType.Unlimited, nullable = true)() /** Creates a new AttributeReference of type decimal */ - def decimal(precision: Int, scale: Int) = + def decimal(precision: Int, scale: Int): AttributeReference = AttributeReference(s, DecimalType(precision, scale), nullable = true)() /** Creates a new AttributeReference of type timestamp */ - def timestamp = AttributeReference(s, TimestampType, nullable = true)() + def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)() /** Creates a new AttributeReference of type binary */ - def binary = AttributeReference(s, BinaryType, nullable = true)() + def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)() /** Creates a new AttributeReference of type array */ - def array(dataType: DataType) = AttributeReference(s, ArrayType(dataType), nullable = true)() + def array(dataType: DataType): AttributeReference = + AttributeReference(s, ArrayType(dataType), nullable = true)() /** Creates a new AttributeReference of type map */ def map(keyType: DataType, valueType: DataType): AttributeReference = map(MapType(keyType, valueType)) - def map(mapType: MapType) = AttributeReference(s, mapType, nullable = true)() + + def map(mapType: MapType): AttributeReference = + AttributeReference(s, mapType, nullable = true)() /** Creates a new AttributeReference of type struct */ def struct(fields: StructField*): AttributeReference = struct(StructType(fields)) - def struct(structType: StructType) = AttributeReference(s, structType, nullable = true)() + def struct(structType: StructType): AttributeReference = + AttributeReference(s, structType, nullable = true)() } implicit class DslAttribute(a: AttributeReference) { - def notNull = a.withNullability(false) - def nullable = a.withNullability(true) + def notNull: AttributeReference = a.withNullability(false) + def nullable: AttributeReference = a.withNullability(true) // Protobuf terminology - def required = a.withNullability(false) + def required: AttributeReference = a.withNullability(false) - def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable) + def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable) } } @@ -241,23 +247,23 @@ package object dsl { abstract class LogicalPlanFunctions { def logicalPlan: LogicalPlan - def select(exprs: NamedExpression*) = Project(exprs, logicalPlan) + def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) - def where(condition: Expression) = Filter(condition, logicalPlan) + def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) - def limit(limitExpr: Expression) = Limit(limitExpr, logicalPlan) + def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) def join( otherPlan: LogicalPlan, joinType: JoinType = Inner, - condition: Option[Expression] = None) = + condition: Option[Expression] = None): LogicalPlan = Join(logicalPlan, otherPlan, joinType, condition) - def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, true, logicalPlan) + def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan) - def sortBy(sortExprs: SortOrder*) = Sort(sortExprs, false, logicalPlan) + def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan) - def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*) = { + def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = { val aliasedExprs = aggregateExprs.map { case ne: NamedExpression => ne case e => Alias(e, e.toString)() @@ -265,41 +271,43 @@ package object dsl { Aggregate(groupingExprs, aliasedExprs, logicalPlan) } - def subquery(alias: Symbol) = Subquery(alias.name, logicalPlan) + def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan) - def unionAll(otherPlan: LogicalPlan) = Union(logicalPlan, otherPlan) + def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) - def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) = + def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan = Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) def sample( fraction: Double, withReplacement: Boolean = true, - seed: Int = (math.random * 1000).toInt) = + seed: Int = (math.random * 1000).toInt): LogicalPlan = Sample(fraction, withReplacement, seed, logicalPlan) def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, - alias: Option[String] = None) = + alias: Option[String] = None): LogicalPlan = Generate(generator, join, outer, None, logicalPlan) - def insertInto(tableName: String, overwrite: Boolean = false) = + def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite) - def analyze = analysis.SimpleAnalyzer(logicalPlan) + def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer(logicalPlan)) } object plans { // scalastyle:ignore implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) extends LogicalPlanFunctions { - def writeToFile(path: String) = WriteToFile(path, logicalPlan) + def writeToFile(path: String): LogicalPlan = WriteToFile(path, logicalPlan) } } case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) { - def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) + def call(args: Expression*): ScalaUdf = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) + } } // scalastyle:off diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 82e760b6c6916..96a11e352ec50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -23,7 +23,9 @@ package org.apache.spark.sql.catalyst.expressions * of the name, or the expected nullability). */ object AttributeMap { - def apply[A](kvs: Seq[(Attribute, A)]) = new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap) + def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = { + new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap) + } } class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index adaeab0b5c027..11b4eb5c888be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -19,27 +19,27 @@ package org.apache.spark.sql.catalyst.expressions protected class AttributeEquals(val a: Attribute) { - override def hashCode() = a match { + override def hashCode(): Int = a match { case ar: AttributeReference => ar.exprId.hashCode() case a => a.hashCode() } - override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match { + override def equals(other: Any): Boolean = (a, other.asInstanceOf[AttributeEquals].a) match { case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId case (a1, a2) => a1 == a2 } } object AttributeSet { - def apply(a: Attribute) = - new AttributeSet(Set(new AttributeEquals(a))) + def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a))) /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ - def apply(baseSet: Seq[Expression]) = + def apply(baseSet: Seq[Expression]): AttributeSet = { new AttributeSet( baseSet .flatMap(_.references) .map(new AttributeEquals(_)).toSet) + } } /** @@ -57,8 +57,9 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) extends Traversable[Attribute] with Serializable { /** Returns true if the members of this AttributeSet and other are the same. */ - override def equals(other: Any) = other match { - case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains) + override def equals(other: Any): Boolean = other match { + case otherSet: AttributeSet => + otherSet.size == baseSet.size && baseSet.map(_.a).forall(otherSet.contains) case _ => false } @@ -81,32 +82,34 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) * Returns true if the [[Attribute Attributes]] in this set are a subset of the Attributes in * `other`. */ - def subsetOf(other: AttributeSet) = baseSet.subsetOf(other.baseSet) + def subsetOf(other: AttributeSet): Boolean = baseSet.subsetOf(other.baseSet) /** * Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found * in `other`. */ - def --(other: Traversable[NamedExpression]) = + def --(other: Traversable[NamedExpression]): AttributeSet = new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute))) /** * Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found * in `other`. */ - def ++(other: AttributeSet) = new AttributeSet(baseSet ++ other.baseSet) + def ++(other: AttributeSet): AttributeSet = new AttributeSet(baseSet ++ other.baseSet) /** * Returns a new [[AttributeSet]] contain only the [[Attribute Attributes]] where `f` evaluates to * true. */ - override def filter(f: Attribute => Boolean) = new AttributeSet(baseSet.filter(ae => f(ae.a))) + override def filter(f: Attribute => Boolean): AttributeSet = + new AttributeSet(baseSet.filter(ae => f(ae.a))) /** * Returns a new [[AttributeSet]] that only contains [[Attribute Attributes]] that are found in * `this` and `other`. */ - def intersect(other: AttributeSet) = new AttributeSet(baseSet.intersect(other.baseSet)) + def intersect(other: AttributeSet): AttributeSet = + new AttributeSet(baseSet.intersect(other.baseSet)) override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 76a9f08dea85f..2225621dbaabd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -32,7 +32,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) type EvaluatedType = Any - override def toString = s"input[$ordinal]" + override def toString: String = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index b1bc858478ee1..31f1a5fdc7e53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -29,9 +29,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w override lazy val resolved = childrenResolved && resolve(child.dataType, dataType) - override def foldable = child.foldable + override def foldable: Boolean = child.foldable - override def nullable = forceNullable(child.dataType, dataType) || child.nullable + override def nullable: Boolean = forceNullable(child.dataType, dataType) || child.nullable private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match { case (StringType, _: NumericType) => true @@ -103,7 +103,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } - override def toString = s"CAST($child, $dataType)" + override def toString: String = s"CAST($child, $dataType)" type EvaluatedType = Any @@ -394,10 +394,17 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val casts = from.fields.zip(to.fields).map { case (fromField, toField) => cast(fromField.dataType, toField.dataType) } - // TODO: This is very slow! - buildCast[Row](_, row => Row(row.toSeq.zip(casts).map { - case (v, cast) => if (v == null) null else cast(v) - }: _*)) + // TODO: Could be faster? + val newRow = new GenericMutableRow(from.fields.size) + buildCast[Row](_, row => { + var i = 0 + while (i < row.length) { + val v = row(i) + newRow.update(i, if (v == null) null else casts(i)(v)) + i += 1 + } + newRow.copy() + }) } private[this] def cast(from: DataType, to: DataType): Any => Any = to match { @@ -430,14 +437,14 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w object Cast { // `SimpleDateFormat` is not thread-safe. private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { - override def initialValue() = { + override def initialValue(): SimpleDateFormat = { new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") } } // `SimpleDateFormat` is not thread-safe. private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { - override def initialValue() = { + override def initialValue(): SimpleDateFormat = { new SimpleDateFormat("yyyy-MM-dd") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6ad39b8372cfb..4e3bbc06a5b4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -65,7 +65,7 @@ abstract class Expression extends TreeNode[Expression] { * Returns true if all the children of this expression have been resolved to a specific schema * and false if any still contains any unresolved placeholders. */ - def childrenResolved = !children.exists(!_.resolved) + def childrenResolved: Boolean = !children.exists(!_.resolved) /** * Returns a string representation of this expression that does not have developer centric @@ -84,9 +84,9 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express def symbol: String - override def foldable = left.foldable && right.foldable + override def foldable: Boolean = left.foldable && right.foldable - override def toString = s"($left $symbol $right)" + override def toString: String = s"($left $symbol $right)" } abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { @@ -104,8 +104,8 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio case class GroupExpression(children: Seq[Expression]) extends Expression { self: Product => type EvaluatedType = Seq[Any] - override def eval(input: Row): EvaluatedType = ??? - override def nullable = false - override def foldable = false - override def dataType = ??? + override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def nullable: Boolean = false + override def foldable: Boolean = false + override def dataType: DataType = throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index db5d897ee569f..c2866cd955409 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -40,7 +40,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { new GenericRow(outputArray) } - override def toString = s"Row => [${exprArray.mkString(",")}]" + override def toString: String = s"Row => [${exprArray.mkString(",")}]" } /** @@ -107,12 +107,12 @@ class JoinedRow extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -142,7 +142,7 @@ class JoinedRow extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -153,7 +153,7 @@ class JoinedRow extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -207,12 +207,12 @@ class JoinedRow2 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -242,7 +242,7 @@ class JoinedRow2 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -253,7 +253,7 @@ class JoinedRow2 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -301,12 +301,12 @@ class JoinedRow3 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -336,7 +336,7 @@ class JoinedRow3 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -347,7 +347,7 @@ class JoinedRow3 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -395,12 +395,12 @@ class JoinedRow4 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -430,7 +430,7 @@ class JoinedRow4 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -441,7 +441,7 @@ class JoinedRow4 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -489,12 +489,12 @@ class JoinedRow5 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -524,7 +524,7 @@ class JoinedRow5 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -535,7 +535,7 @@ class JoinedRow5 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala index b2c6d3029031d..f5fea3f015dc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala @@ -18,16 +18,19 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Random -import org.apache.spark.sql.types.DoubleType + +import org.apache.spark.sql.types.{DataType, DoubleType} case object Rand extends LeafExpression { - override def dataType = DoubleType - override def nullable = false + override def dataType: DataType = DoubleType + override def nullable: Boolean = false private[this] lazy val rand = new Random - override def eval(input: Row = null) = rand.nextDouble().asInstanceOf[EvaluatedType] + override def eval(input: Row = null): EvaluatedType = { + rand.nextDouble().asInstanceOf[EvaluatedType] + } - override def toString = "RAND()" + override def toString: String = "RAND()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 8a36c6810790d..389dc4f745723 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -29,9 +29,9 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi type EvaluatedType = Any - def nullable = true + override def nullable: Boolean = true - override def toString = s"scalaUDF(${children.mkString(",")})" + override def toString: String = s"scalaUDF(${children.mkString(",")})" // scalastyle:off @@ -39,363 +39,669 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) - val evals = (0 to x - 1).map(x => s" ScalaReflection.convertToScala(children($x).eval(input), children($x).dataType)").reduce(_ + ",\n " + _) - - s""" - case $x => - function.asInstanceOf[($anys) => Any]( - $evals) - """ + val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _) + val evals = (0 to x - 1).map(x => s"ScalaReflection.convertToScala(child$x.eval(input), child$x.dataType)").reduce(_ + ",\n " + _) + + s""" case $x => + val func = function.asInstanceOf[($anys) => Any] + $childs + (input: Row) => { + func( + $evals) + } + """ }.foreach(println) */ - - override def eval(input: Row): Any = { - val result = children.size match { - case 0 => function.asInstanceOf[() => Any]() - case 1 => - function.asInstanceOf[(Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType)) - - - case 2 => - function.asInstanceOf[(Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType)) - - - case 3 => - function.asInstanceOf[(Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType)) - - - case 4 => - function.asInstanceOf[(Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType)) - - - case 5 => - function.asInstanceOf[(Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType)) - - - case 6 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType)) - - - case 7 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType)) - - - case 8 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType)) - - - case 9 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType)) - - - case 10 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType)) - - - case 11 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType)) - - - case 12 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType)) - - - case 13 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType)) - - - case 14 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType)) - - - case 15 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType)) - - - case 16 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType)) - - - case 17 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType)) - - - case 18 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType)) - - - case 19 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), - ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType)) - - - case 20 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), - ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), - ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType)) - - - case 21 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), - ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), - ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), - ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType)) - - - case 22 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), - ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), - ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), - ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType), - ScalaReflection.convertToScala(children(21).eval(input), children(21).dataType)) - - } - // scalastyle:on - - ScalaReflection.convertToCatalyst(result, dataType) + + val f = children.size match { + case 0 => + val func = function.asInstanceOf[() => Any] + (input: Row) => { + func() + } + + case 1 => + val func = function.asInstanceOf[(Any) => Any] + val child0 = children(0) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType)) + } + + case 2 => + val func = function.asInstanceOf[(Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType)) + } + + case 3 => + val func = function.asInstanceOf[(Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType)) + } + + case 4 => + val func = function.asInstanceOf[(Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType)) + } + + case 5 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType)) + } + + case 6 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType)) + } + + case 7 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType)) + } + + case 8 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType)) + } + + case 9 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType)) + } + + case 10 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType)) + } + + case 11 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType)) + } + + case 12 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType)) + } + + case 13 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType)) + } + + case 14 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType)) + } + + case 15 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType)) + } + + case 16 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType)) + } + + case 17 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType)) + } + + case 18 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType), + ScalaReflection.convertToScala(child17.eval(input), child17.dataType)) + } + + case 19 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + val child18 = children(18) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType), + ScalaReflection.convertToScala(child17.eval(input), child17.dataType), + ScalaReflection.convertToScala(child18.eval(input), child18.dataType)) + } + + case 20 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + val child18 = children(18) + val child19 = children(19) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType), + ScalaReflection.convertToScala(child17.eval(input), child17.dataType), + ScalaReflection.convertToScala(child18.eval(input), child18.dataType), + ScalaReflection.convertToScala(child19.eval(input), child19.dataType)) + } + + case 21 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + val child18 = children(18) + val child19 = children(19) + val child20 = children(20) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType), + ScalaReflection.convertToScala(child17.eval(input), child17.dataType), + ScalaReflection.convertToScala(child18.eval(input), child18.dataType), + ScalaReflection.convertToScala(child19.eval(input), child19.dataType), + ScalaReflection.convertToScala(child20.eval(input), child20.dataType)) + } + + case 22 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + val child18 = children(18) + val child19 = children(19) + val child20 = children(20) + val child21 = children(21) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType), + ScalaReflection.convertToScala(child17.eval(input), child17.dataType), + ScalaReflection.convertToScala(child18.eval(input), child18.dataType), + ScalaReflection.convertToScala(child19.eval(input), child19.dataType), + ScalaReflection.convertToScala(child20.eval(input), child20.dataType), + ScalaReflection.convertToScala(child21.eval(input), child21.dataType)) + } } + + // scalastyle:on + + override def eval(input: Row): Any = ScalaReflection.convertToCatalyst(f(input), dataType) + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index d00b2ac09745c..83074eb1e6310 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.types.DataType abstract sealed class SortDirection case object Ascending extends SortDirection @@ -31,12 +32,12 @@ case object Descending extends SortDirection case class SortOrder(child: Expression, direction: SortDirection) extends Expression with trees.UnaryNode[Expression] { - override def dataType = child.dataType - override def nullable = child.nullable + override def dataType: DataType = child.dataType + override def nullable: Boolean = child.nullable // SortOrder itself is never evaluated. override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" + override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 21d714c9a8c3b..47b6f358ed1b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -62,126 +62,126 @@ abstract class MutableValue extends Serializable { var isNull: Boolean = true def boxed: Any def update(v: Any) - def copy(): this.type + def copy(): MutableValue } final class MutableInt extends MutableValue { var value: Int = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Int] + value = v.asInstanceOf[Int] } - def copy() = { + override def copy(): MutableInt = { val newCopy = new MutableInt newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableInt] } } final class MutableFloat extends MutableValue { var value: Float = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Float] + value = v.asInstanceOf[Float] } - def copy() = { + override def copy(): MutableFloat = { val newCopy = new MutableFloat newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableFloat] } } final class MutableBoolean extends MutableValue { var value: Boolean = false - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Boolean] + value = v.asInstanceOf[Boolean] } - def copy() = { + override def copy(): MutableBoolean = { val newCopy = new MutableBoolean newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableBoolean] } } final class MutableDouble extends MutableValue { var value: Double = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Double] + value = v.asInstanceOf[Double] } - def copy() = { + override def copy(): MutableDouble = { val newCopy = new MutableDouble newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableDouble] } } final class MutableShort extends MutableValue { var value: Short = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = value = { isNull = false v.asInstanceOf[Short] } - def copy() = { + override def copy(): MutableShort = { val newCopy = new MutableShort newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableShort] } } final class MutableLong extends MutableValue { var value: Long = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = value = { isNull = false v.asInstanceOf[Long] } - def copy() = { + override def copy(): MutableLong = { val newCopy = new MutableLong newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableLong] } } final class MutableByte extends MutableValue { var value: Byte = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = value = { isNull = false v.asInstanceOf[Byte] } - def copy() = { + override def copy(): MutableByte = { val newCopy = new MutableByte newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableByte] } } final class MutableAny extends MutableValue { var value: Any = _ - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Any] + value = v.asInstanceOf[Any] } - def copy() = { + override def copy(): MutableAny = { val newCopy = new MutableAny newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableAny] } } @@ -234,9 +234,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR if (value == null) setNullAt(ordinal) else values(ordinal).update(value) } - override def setString(ordinal: Int, value: String) = update(ordinal, value) + override def setString(ordinal: Int, value: String): Unit = update(ordinal, value) - override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] + override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String] override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 5297d1e31246c..30da4faa3f1c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -79,27 +79,29 @@ abstract class AggregateFunction /** Base should return the generic aggregate expression that this function is computing */ val base: AggregateExpression - override def nullable = base.nullable - override def dataType = base.dataType + override def nullable: Boolean = base.nullable + override def dataType: DataType = base.dataType def update(input: Row): Unit // Do we really need this? - override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray) + override def newInstance(): AggregateFunction = { + makeCopy(productIterator.map { case a: AnyRef => a }.toArray) + } } case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true - override def dataType = child.dataType - override def toString = s"MIN($child)" + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"MIN($child)" override def asPartial: SplitEvaluation = { val partialMin = Alias(Min(child), "PartialMin")() SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil) } - override def newInstance() = new MinFunction(child, this) + override def newInstance(): MinFunction = new MinFunction(child, this) } case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -121,16 +123,16 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true - override def dataType = child.dataType - override def toString = s"MAX($child)" + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"MAX($child)" override def asPartial: SplitEvaluation = { val partialMax = Alias(Max(child), "PartialMax")() SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil) } - override def newInstance() = new MaxFunction(child, this) + override def newInstance(): MaxFunction = new MaxFunction(child, this) } case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -152,29 +154,29 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = LongType - override def toString = s"COUNT($child)" + override def nullable: Boolean = false + override def dataType: LongType.type = LongType + override def toString: String = s"COUNT($child)" override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil) } - override def newInstance() = new CountFunction(child, this) + override def newInstance(): CountFunction = new CountFunction(child, this) } case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { def this() = this(null) - override def children = expressions + override def children: Seq[Expression] = expressions - override def nullable = false - override def dataType = LongType - override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})" - override def newInstance() = new CountDistinctFunction(expressions, this) + override def nullable: Boolean = false + override def dataType: DataType = LongType + override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})" + override def newInstance(): CountDistinctFunction = new CountDistinctFunction(expressions, this) - override def asPartial = { + override def asPartial: SplitEvaluation = { val partialSet = Alias(CollectHashSet(expressions), "partialSets")() SplitEvaluation( CombineSetsAndCount(partialSet.toAttribute), @@ -185,11 +187,11 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { def this() = this(null) - override def children = expressions - override def nullable = false - override def dataType = ArrayType(expressions.head.dataType) - override def toString = s"AddToHashSet(${expressions.mkString(",")})" - override def newInstance() = new CollectHashSetFunction(expressions, this) + override def children: Seq[Expression] = expressions + override def nullable: Boolean = false + override def dataType: ArrayType = ArrayType(expressions.head.dataType) + override def toString: String = s"AddToHashSet(${expressions.mkString(",")})" + override def newInstance(): CollectHashSetFunction = new CollectHashSetFunction(expressions, this) } case class CollectHashSetFunction( @@ -219,11 +221,13 @@ case class CollectHashSetFunction( case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression { def this() = this(null) - override def children = inputSet :: Nil - override def nullable = false - override def dataType = LongType - override def toString = s"CombineAndCount($inputSet)" - override def newInstance() = new CombineSetsAndCountFunction(inputSet, this) + override def children: Seq[Expression] = inputSet :: Nil + override def nullable: Boolean = false + override def dataType: DataType = LongType + override def toString: String = s"CombineAndCount($inputSet)" + override def newInstance(): CombineSetsAndCountFunction = { + new CombineSetsAndCountFunction(inputSet, this) + } } case class CombineSetsAndCountFunction( @@ -249,27 +253,31 @@ case class CombineSetsAndCountFunction( case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = child.dataType - override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this, relativeSD) + override def nullable: Boolean = false + override def dataType: DataType = child.dataType + override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" + override def newInstance(): ApproxCountDistinctPartitionFunction = { + new ApproxCountDistinctPartitionFunction(child, this, relativeSD) + } } case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = LongType - override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance() = new ApproxCountDistinctMergeFunction(child, this, relativeSD) + override def nullable: Boolean = false + override def dataType: LongType.type = LongType + override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" + override def newInstance(): ApproxCountDistinctMergeFunction = { + new ApproxCountDistinctMergeFunction(child, this, relativeSD) + } } case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = LongType - override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" + override def nullable: Boolean = false + override def dataType: LongType.type = LongType + override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" override def asPartial: SplitEvaluation = { val partialCount = @@ -280,14 +288,14 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) partialCount :: Nil) } - override def newInstance() = new CountDistinctFunction(child :: Nil, this) + override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) } case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true + override def nullable: Boolean = true - override def dataType = child.dataType match { + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive case DecimalType.Unlimited => @@ -296,7 +304,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN DoubleType } - override def toString = s"AVG($child)" + override def toString: String = s"AVG($child)" override def asPartial: SplitEvaluation = { child.dataType match { @@ -323,14 +331,14 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN } } - override def newInstance() = new AverageFunction(child, this) + override def newInstance(): AverageFunction = new AverageFunction(child, this) } case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true + override def nullable: Boolean = true - override def dataType = child.dataType match { + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive case DecimalType.Unlimited => @@ -339,7 +347,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ child.dataType } - override def toString = s"SUM($child)" + override def toString: String = s"SUM($child)" override def asPartial: SplitEvaluation = { child.dataType match { @@ -357,7 +365,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ } } - override def newInstance() = new SumFunction(child, this) + override def newInstance(): SumFunction = new SumFunction(child, this) } /** @@ -377,19 +385,19 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ case class CombineSum(child: Expression) extends AggregateExpression { def this() = this(null) - override def children = child :: Nil - override def nullable = true - override def dataType = child.dataType - override def toString = s"CombineSum($child)" - override def newInstance() = new CombineSumFunction(child, this) + override def children: Seq[Expression] = child :: Nil + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"CombineSum($child)" + override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) } case class SumDistinct(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { def this() = this(null) - override def nullable = true - override def dataType = child.dataType match { + override def nullable: Boolean = true + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive case DecimalType.Unlimited => @@ -397,10 +405,10 @@ case class SumDistinct(child: Expression) case _ => child.dataType } - override def toString = s"SUM(DISTINCT ${child})" - override def newInstance() = new SumDistinctFunction(child, this) + override def toString: String = s"SUM(DISTINCT $child)" + override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) - override def asPartial = { + override def asPartial: SplitEvaluation = { val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() SplitEvaluation( CombineSetsAndSum(partialSet.toAttribute, this), @@ -411,11 +419,13 @@ case class SumDistinct(child: Expression) case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { def this() = this(null, null) - override def children = inputSet :: Nil - override def nullable = true - override def dataType = base.dataType - override def toString = s"CombineAndSum($inputSet)" - override def newInstance() = new CombineSetsAndSumFunction(inputSet, this) + override def children: Seq[Expression] = inputSet :: Nil + override def nullable: Boolean = true + override def dataType: DataType = base.dataType + override def toString: String = s"CombineAndSum($inputSet)" + override def newInstance(): CombineSetsAndSumFunction = { + new CombineSetsAndSumFunction(inputSet, this) + } } case class CombineSetsAndSumFunction( @@ -449,9 +459,9 @@ case class CombineSetsAndSumFunction( } case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true - override def dataType = child.dataType - override def toString = s"FIRST($child)" + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"FIRST($child)" override def asPartial: SplitEvaluation = { val partialFirst = Alias(First(child), "PartialFirst")() @@ -459,14 +469,14 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod First(partialFirst.toAttribute), partialFirst :: Nil) } - override def newInstance() = new FirstFunction(child, this) + override def newInstance(): FirstFunction = new FirstFunction(child, this) } case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references - override def nullable = true - override def dataType = child.dataType - override def toString = s"LAST($child)" + override def references: AttributeSet = child.references + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"LAST($child)" override def asPartial: SplitEvaluation = { val partialLast = Alias(Last(child), "PartialLast")() @@ -474,7 +484,7 @@ case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode Last(partialLast.toAttribute), partialLast :: Nil) } - override def newInstance() = new LastFunction(child, this) + override def newInstance(): LastFunction = new LastFunction(child, this) } case class AverageFunction(expr: Expression, base: AggregateExpression) @@ -713,6 +723,7 @@ case class LastFunction(expr: Expression, base: AggregateExpression) extends Agg result = input } - override def eval(input: Row): Any = if (result != null) expr.eval(result.asInstanceOf[Row]) - else null + override def eval(input: Row): Any = { + if (result != null) expr.eval(result.asInstanceOf[Row]) else null + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 00b0d3c683fe2..1f6526ef66c56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -24,10 +24,10 @@ import org.apache.spark.sql.types._ case class UnaryMinus(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = child.dataType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"-$child" + override def dataType: DataType = child.dataType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"-$child" lazy val numeric = dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] @@ -47,10 +47,10 @@ case class UnaryMinus(child: Expression) extends UnaryExpression { case class Sqrt(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = DoubleType - override def foldable = child.foldable - def nullable = true - override def toString = s"SQRT($child)" + override def dataType: DataType = DoubleType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = true + override def toString: String = s"SQRT($child)" lazy val numeric = child.dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] @@ -74,14 +74,14 @@ abstract class BinaryArithmetic extends BinaryExpression { type EvaluatedType = Any - def nullable = left.nullable || right.nullable + def nullable: Boolean = left.nullable || right.nullable override lazy val resolved = left.resolved && right.resolved && left.dataType == right.dataType && !DecimalType.isFixed(left.dataType) - def dataType = { + def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") @@ -108,7 +108,7 @@ abstract class BinaryArithmetic extends BinaryExpression { } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "+" + override def symbol: String = "+" lazy val numeric = dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] @@ -131,7 +131,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "-" + override def symbol: String = "-" lazy val numeric = dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] @@ -154,7 +154,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "*" + override def symbol: String = "*" lazy val numeric = dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] @@ -177,9 +177,9 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "/" + override def symbol: String = "/" - override def nullable = true + override def nullable: Boolean = true lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div @@ -203,9 +203,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "%" + override def symbol: String = "%" - override def nullable = true + override def nullable: Boolean = true lazy val integral = dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] @@ -232,7 +232,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet * A function that calculates bitwise and(&) of two numbers. */ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "&" + override def symbol: String = "&" lazy val and: (Any, Any) => Any = dataType match { case ByteType => @@ -253,7 +253,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme * A function that calculates bitwise or(|) of two numbers. */ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "|" + override def symbol: String = "|" lazy val or: (Any, Any) => Any = dataType match { case ByteType => @@ -274,7 +274,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet * A function that calculates bitwise xor(^) of two numbers. */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "^" + override def symbol: String = "^" lazy val xor: (Any, Any) => Any = dataType match { case ByteType => @@ -297,10 +297,10 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme case class BitwiseNot(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = child.dataType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"~$child" + override def dataType: DataType = child.dataType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"~$child" lazy val not: (Any) => Any = dataType match { case ByteType => @@ -327,17 +327,17 @@ case class BitwiseNot(child: Expression) extends UnaryExpression { case class MaxOf(left: Expression, right: Expression) extends Expression { type EvaluatedType = Any - override def foldable = left.foldable && right.foldable + override def foldable: Boolean = left.foldable && right.foldable - override def nullable = left.nullable && right.nullable + override def nullable: Boolean = left.nullable && right.nullable - override def children = left :: right :: Nil + override def children: Seq[Expression] = left :: right :: Nil override lazy val resolved = left.resolved && right.resolved && left.dataType == right.dataType - override def dataType = { + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") @@ -366,7 +366,7 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { } } - override def toString = s"MaxOf($left, $right)" + override def toString: String = s"MaxOf($left, $right)" } /** @@ -375,10 +375,10 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { case class Abs(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = child.dataType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"Abs($child)" + override def dataType: DataType = child.dataType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"Abs($child)" lazy val numeric = dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e48b8cde20eda..d1abf3c0b64a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -91,7 +91,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val startTime = System.nanoTime() val result = create(in) val endTime = System.nanoTime() - def timeMs = (endTime - startTime).toDouble / 1000000 + def timeMs: Double = (endTime - startTime).toDouble / 1000000 logInfo(s"Code generated expression $in in $timeMs ms") result } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 68051a2a2007e..3fd78db297462 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -27,12 +27,12 @@ import org.apache.spark.sql.types._ case class GetItem(child: Expression, ordinal: Expression) extends Expression { type EvaluatedType = Any - val children = child :: ordinal :: Nil + val children: Seq[Expression] = child :: ordinal :: Nil /** `Null` is returned for invalid ordinals. */ - override def nullable = true - override def foldable = child.foldable && ordinal.foldable + override def nullable: Boolean = true + override def foldable: Boolean = child.foldable && ordinal.foldable - def dataType = child.dataType match { + override def dataType: DataType = child.dataType match { case ArrayType(dt, _) => dt case MapType(_, vt, _) => vt } @@ -40,7 +40,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { childrenResolved && (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) - override def toString = s"$child[$ordinal]" + override def toString: String = s"$child[$ordinal]" override def eval(input: Row): Any = { val value = child.eval(input) @@ -75,8 +75,8 @@ trait GetField extends UnaryExpression { self: Product => type EvaluatedType = Any - override def foldable = child.foldable - override def toString = s"$child.${field.name}" + override def foldable: Boolean = child.foldable + override def toString: String = s"$child.${field.name}" def field: StructField } @@ -86,8 +86,8 @@ trait GetField extends UnaryExpression { */ case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField { - def dataType = field.dataType - override def nullable = child.nullable || field.nullable + override def dataType: DataType = field.dataType + override def nullable: Boolean = child.nullable || field.nullable override def eval(input: Row): Any = { val baseValue = child.eval(input).asInstanceOf[Row] @@ -101,8 +101,8 @@ case class StructGetField(child: Expression, field: StructField, ordinal: Int) e case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean) extends GetField { - def dataType = ArrayType(field.dataType, containsNull) - override def nullable = child.nullable + override def dataType: DataType = ArrayType(field.dataType, containsNull) + override def nullable: Boolean = child.nullable override def eval(input: Row): Any = { val baseValue = child.eval(input).asInstanceOf[Seq[Row]] @@ -120,7 +120,7 @@ case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, co case class CreateArray(children: Seq[Expression]) extends Expression { override type EvaluatedType = Any - override def foldable = !children.exists(!_.foldable) + override def foldable: Boolean = !children.exists(!_.foldable) lazy val childTypes = children.map(_.dataType).distinct @@ -140,5 +140,5 @@ case class CreateArray(children: Seq[Expression]) extends Expression { children.map(_.eval(input)) } - override def toString = s"Array(${children.mkString(",")})" + override def toString: String = s"Array(${children.mkString(",")})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 83d8c1d42bca4..adb94df7d1c7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -24,9 +24,9 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { override type EvaluatedType = Any override def dataType: DataType = LongType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"UnscaledValue($child)" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"UnscaledValue($child)" override def eval(input: Row): Any = { val childResult = child.eval(input) @@ -43,9 +43,9 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un override type EvaluatedType = Decimal override def dataType: DataType = DecimalType(precision, scale) - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"MakeDecimal($child,$precision,$scale)" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"MakeDecimal($child,$precision,$scale)" override def eval(input: Row): Decimal = { val childResult = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 0983d274def3f..860b72fad38b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -45,7 +45,7 @@ abstract class Generator extends Expression { override lazy val dataType = ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) - override def nullable = false + override def nullable: Boolean = false /** * Should be overridden by specific generators. Called only once for each instance to ensure @@ -89,7 +89,7 @@ case class UserDefinedGenerator( function(inputRow(input)) } - override def toString = s"UserDefinedGenerator(${children.mkString(",")})" + override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" } /** @@ -130,5 +130,5 @@ case class Explode(attributeNames: Seq[String], child: Expression) } } - override def toString() = s"explode($child)" + override def toString: String = s"explode($child)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 9ff66563c8164..19f3fc9c2291a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -64,14 +64,13 @@ object IntegerLiteral { case class Literal(value: Any, dataType: DataType) extends LeafExpression { - override def foldable = true - def nullable = value == null + override def foldable: Boolean = true + override def nullable: Boolean = value == null - - override def toString = if (value != null) value.toString else "null" + override def toString: String = if (value != null) value.toString else "null" type EvaluatedType = Any - override def eval(input: Row):Any = value + override def eval(input: Row): Any = value } // TODO: Specialize @@ -79,9 +78,9 @@ case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean extends LeafExpression { type EvaluatedType = Any - def update(expression: Expression, input: Row) = { + def update(expression: Expression, input: Row): Unit = { value = expression.eval(input) } - override def eval(input: Row) = value + override def eval(input: Row): Any = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 17f7f9fe51376..bcbcbeb31c7b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.trees.LeafNode import org.apache.spark.sql.types._ object NamedExpression { private val curId = new java.util.concurrent.atomic.AtomicLong() - def newExprId = ExprId(curId.getAndIncrement()) + def newExprId: ExprId = ExprId(curId.getAndIncrement()) def unapply(expr: NamedExpression): Option[(String, DataType)] = Some(expr.name, expr.dataType) } @@ -41,6 +42,13 @@ abstract class NamedExpression extends Expression { def name: String def exprId: ExprId + /** + * Returns a dot separated fully qualified name for this attribute. Given that there can be + * multiple qualifiers, it is possible that there are other possible way to refer to this + * attribute. + */ + def qualifiedName: String = (qualifiers.headOption.toSeq :+ name).mkString(".") + /** * All possible qualifiers for the expression. * @@ -72,13 +80,13 @@ abstract class NamedExpression extends Expression { abstract class Attribute extends NamedExpression { self: Product => - override def references = AttributeSet(this) + override def references: AttributeSet = AttributeSet(this) def withNullability(newNullability: Boolean): Attribute def withQualifiers(newQualifiers: Seq[String]): Attribute def withName(newName: String): Attribute - def toAttribute = this + def toAttribute: Attribute = this def newInstance(): Attribute } @@ -95,25 +103,30 @@ abstract class Attribute extends NamedExpression { * @param name the name to be associated with the result of computing [[child]]. * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this * alias. Auto-assigned if left blank. + * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's. */ -case class Alias(child: Expression, name: String) - (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) +case class Alias(child: Expression, name: String)( + val exprId: ExprId = NamedExpression.newExprId, + val qualifiers: Seq[String] = Nil, + val explicitMetadata: Option[Metadata] = None) extends NamedExpression with trees.UnaryNode[Expression] { override type EvaluatedType = Any - override def eval(input: Row) = child.eval(input) + override def eval(input: Row): Any = child.eval(input) - override def dataType = child.dataType - override def nullable = child.nullable + override def dataType: DataType = child.dataType + override def nullable: Boolean = child.nullable override def metadata: Metadata = { - child match { - case named: NamedExpression => named.metadata - case _ => Metadata.empty + explicitMetadata.getOrElse { + child match { + case named: NamedExpression => named.metadata + case _ => Metadata.empty + } } } - override def toAttribute = { + override def toAttribute: Attribute = { if (resolved) { AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers) } else { @@ -123,11 +136,14 @@ case class Alias(child: Expression, name: String) override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix" - override protected final def otherCopyArgs = exprId :: qualifiers :: Nil + override protected final def otherCopyArgs: Seq[AnyRef] = { + exprId :: qualifiers :: explicitMetadata :: Nil + } override def equals(other: Any): Boolean = other match { case a: Alias => - name == a.name && exprId == a.exprId && child == a.child && qualifiers == a.qualifiers + name == a.name && exprId == a.exprId && child == a.child && qualifiers == a.qualifiers && + explicitMetadata == a.explicitMetadata case _ => false } } @@ -153,7 +169,7 @@ case class AttributeReference( val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType case _ => false } @@ -167,7 +183,7 @@ case class AttributeReference( h } - override def newInstance() = + override def newInstance(): AttributeReference = AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers) /** @@ -192,7 +208,7 @@ case class AttributeReference( /** * Returns a copy of this [[AttributeReference]] with new qualifiers. */ - override def withQualifiers(newQualifiers: Seq[String]) = { + override def withQualifiers(newQualifiers: Seq[String]): AttributeReference = { if (newQualifiers.toSet == qualifiers.toSet) { this } else { @@ -214,20 +230,22 @@ case class AttributeReference( case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { type EvaluatedType = Any - override def toString = name - - override def withNullability(newNullability: Boolean): Attribute = ??? - override def newInstance(): Attribute = ??? - override def withQualifiers(newQualifiers: Seq[String]): Attribute = ??? - override def withName(newName: String): Attribute = ??? - override def qualifiers: Seq[String] = ??? - override def exprId: ExprId = ??? - override def eval(input: Row): EvaluatedType = ??? - override def nullable: Boolean = ??? + override def toString: String = name + + override def withNullability(newNullability: Boolean): Attribute = + throw new UnsupportedOperationException + override def newInstance(): Attribute = throw new UnsupportedOperationException + override def withQualifiers(newQualifiers: Seq[String]): Attribute = + throw new UnsupportedOperationException + override def withName(newName: String): Attribute = throw new UnsupportedOperationException + override def qualifiers: Seq[String] = throw new UnsupportedOperationException + override def exprId: ExprId = throw new UnsupportedOperationException + override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def nullable: Boolean = throw new UnsupportedOperationException override def dataType: DataType = NullType } object VirtualColumn { - val groupingIdName = "grouping__id" - def newGroupingId = AttributeReference(groupingIdName, IntegerType, false)() + val groupingIdName: String = "grouping__id" + def newGroupingId: AttributeReference = AttributeReference(groupingIdName, IntegerType, false)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 08b982bc671e7..d1f3d4f4ee9ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -19,22 +19,23 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.types.DataType case class Coalesce(children: Seq[Expression]) extends Expression { type EvaluatedType = Any /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ - def nullable = !children.exists(!_.nullable) + override def nullable: Boolean = !children.exists(!_.nullable) // Coalesce is foldable if all children are foldable. - override def foldable = !children.exists(!_.foldable) + override def foldable: Boolean = !children.exists(!_.foldable) // Only resolved if all the children are of the same type. override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1) - override def toString = s"Coalesce(${children.mkString(",")})" + override def toString: String = s"Coalesce(${children.mkString(",")})" - def dataType = if (resolved) { + def dataType: DataType = if (resolved) { children.head.dataType } else { val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ") @@ -54,20 +55,20 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - override def foldable = child.foldable - def nullable = false + override def foldable: Boolean = child.foldable + override def nullable: Boolean = false override def eval(input: Row): Any = { child.eval(input) == null } - override def toString = s"IS NULL $child" + override def toString: String = s"IS NULL $child" } case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - override def foldable = child.foldable - def nullable = false - override def toString = s"IS NOT NULL $child" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = false + override def toString: String = s"IS NOT NULL $child" override def eval(input: Row): Any = { child.eval(input) != null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 0024ef92c0452..7e47cb3fffe12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.{BinaryType, BooleanType, NativeType} +import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, NativeType} object InterpretedPredicate { def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -34,7 +34,7 @@ object InterpretedPredicate { trait Predicate extends Expression { self: Product => - def dataType = BooleanType + override def dataType: DataType = BooleanType type EvaluatedType = Any } @@ -72,13 +72,13 @@ trait PredicateHelper { abstract class BinaryPredicate extends BinaryExpression with Predicate { self: Product => - def nullable = left.nullable || right.nullable + override def nullable: Boolean = left.nullable || right.nullable } case class Not(child: Expression) extends UnaryExpression with Predicate { - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"NOT $child" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"NOT $child" override def eval(input: Row): Any = { child.eval(input) match { @@ -92,10 +92,10 @@ case class Not(child: Expression) extends UnaryExpression with Predicate { * Evaluates to `true` if `list` contains `value`. */ case class In(value: Expression, list: Seq[Expression]) extends Predicate { - def children = value +: list + override def children: Seq[Expression] = value +: list - def nullable = true // TODO: Figure out correct nullability semantics of IN. - override def toString = s"$value IN ${list.mkString("(", ",", ")")}" + override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. + override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: Row): Any = { val evaluatedValue = value.eval(input) @@ -110,10 +110,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { case class InSet(value: Expression, hset: Set[Any]) extends Predicate { - def children = value :: Nil + override def children: Seq[Expression] = value :: Nil - def nullable = true // TODO: Figure out correct nullability semantics of IN. - override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}" + override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. + override def toString: String = s"$value INSET ${hset.mkString("(", ",", ")")}" override def eval(input: Row): Any = { hset.contains(value.eval(input)) @@ -121,7 +121,7 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) extends BinaryPredicate { - def symbol = "&&" + override def symbol: String = "&&" override def eval(input: Row): Any = { val l = left.eval(input) @@ -143,7 +143,7 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate { } case class Or(left: Expression, right: Expression) extends BinaryPredicate { - def symbol = "||" + override def symbol: String = "||" override def eval(input: Row): Any = { val l = left.eval(input) @@ -169,7 +169,8 @@ abstract class BinaryComparison extends BinaryPredicate { } case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "=" + override def symbol: String = "=" + override def eval(input: Row): Any = { val l = left.eval(input) if (l == null) { @@ -185,8 +186,10 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison } case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "<=>" - override def nullable = false + override def symbol: String = "<=>" + + override def nullable: Boolean = false + override def eval(input: Row): Any = { val l = left.eval(input) val r = right.eval(input) @@ -201,9 +204,9 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } case class LessThan(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "<" + override def symbol: String = "<" - lazy val ordering = { + lazy val ordering: Ordering[Any] = { if (left.dataType != right.dataType) { throw new TreeNodeException(this, s"Types do not match ${left.dataType} != ${right.dataType}") @@ -216,7 +219,7 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso override def eval(input: Row): Any = { val evalE1 = left.eval(input) - if(evalE1 == null) { + if (evalE1 == null) { null } else { val evalE2 = right.eval(input) @@ -230,9 +233,9 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso } case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "<=" + override def symbol: String = "<=" - lazy val ordering = { + lazy val ordering: Ordering[Any] = { if (left.dataType != right.dataType) { throw new TreeNodeException(this, s"Types do not match ${left.dataType} != ${right.dataType}") @@ -245,7 +248,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo override def eval(input: Row): Any = { val evalE1 = left.eval(input) - if(evalE1 == null) { + if (evalE1 == null) { null } else { val evalE2 = right.eval(input) @@ -259,9 +262,9 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo } case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { - def symbol = ">" + override def symbol: String = ">" - lazy val ordering = { + lazy val ordering: Ordering[Any] = { if (left.dataType != right.dataType) { throw new TreeNodeException(this, s"Types do not match ${left.dataType} != ${right.dataType}") @@ -288,9 +291,9 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar } case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - def symbol = ">=" + override def symbol: String = ">=" - lazy val ordering = { + lazy val ordering: Ordering[Any] = { if (left.dataType != right.dataType) { throw new TreeNodeException(this, s"Types do not match ${left.dataType} != ${right.dataType}") @@ -303,7 +306,7 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar override def eval(input: Row): Any = { val evalE1 = left.eval(input) - if(evalE1 == null) { + if (evalE1 == null) { null } else { val evalE2 = right.eval(input) @@ -317,13 +320,13 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar } case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends Expression { + extends Expression { - def children = predicate :: trueValue :: falseValue :: Nil - override def nullable = trueValue.nullable || falseValue.nullable + override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil + override def nullable: Boolean = trueValue.nullable || falseValue.nullable override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType - def dataType = { + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException( this, @@ -342,7 +345,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } - override def toString = s"if ($predicate) $trueValue else $falseValue" + override def toString: String = s"if ($predicate) $trueValue else $falseValue" } // scalastyle:off @@ -362,9 +365,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi // scalastyle:on case class CaseWhen(branches: Seq[Expression]) extends Expression { type EvaluatedType = Any - def children = branches - def dataType = { + override def children: Seq[Expression] = branches + + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") } @@ -379,12 +383,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { @transient private[this] lazy val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) - override def nullable = { + override def nullable: Boolean = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) } - override lazy val resolved = { + override lazy val resolved: Boolean = { if (!childrenResolved) { false } else { @@ -415,7 +419,7 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { res } - override def toString = { + override def toString: String = { "CASE" + branches.sliding(2, 2).map { case Seq(cond, value) => s" WHEN $cond THEN $value" case Seq(elseValue) => s" ELSE $elseValue" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index f03d6f71a9fae..a8983df208318 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -44,8 +44,8 @@ trait MutableRow extends Row { */ object EmptyRow extends Row { override def apply(i: Int): Any = throw new UnsupportedOperationException - override def toSeq = Seq.empty - override def length = 0 + override def toSeq: Seq[Any] = Seq.empty + override def length: Int = 0 override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException override def getInt(i: Int): Int = throw new UnsupportedOperationException override def getLong(i: Int): Long = throw new UnsupportedOperationException @@ -56,7 +56,7 @@ object EmptyRow extends Row { override def getByte(i: Int): Byte = throw new UnsupportedOperationException override def getString(i: Int): String = throw new UnsupportedOperationException override def getAs[T](i: Int): T = throw new UnsupportedOperationException - def copy() = this + override def copy(): Row = this } /** @@ -66,17 +66,17 @@ object EmptyRow extends Row { */ class GenericRow(protected[sql] val values: Array[Any]) extends Row { /** No-arg constructor for serialization. */ - def this() = this(null) + protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) - override def toSeq = values.toSeq + override def toSeq: Seq[Any] = values.toSeq - override def length = values.length + override def length: Int = values.length - override def apply(i: Int) = values(i) + override def apply(i: Int): Any = values(i) - override def isNullAt(i: Int) = values(i) == null + override def isNullAt(i: Int): Boolean = values(i) == null override def getInt(i: Int): Int = { if (values(i) == null) sys.error("Failed to check null bit for primitive int value.") @@ -167,16 +167,19 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { case _ => false } - def copy() = this + override def copy(): Row = this } class GenericRowWithSchema(values: Array[Any], override val schema: StructType) extends GenericRow(values) { + + /** No-arg constructor for serialization. */ + protected def this() = this(null, null) } class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { /** No-arg constructor for serialization. */ - def this() = this(null) + protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) @@ -194,7 +197,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value } - override def copy() = new GenericRow(values.clone()) + override def copy(): Row = new GenericRow(values.clone()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 3a5bdca1f07c3..35faa00782e80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -26,17 +26,17 @@ import org.apache.spark.util.collection.OpenHashSet case class NewSet(elementType: DataType) extends LeafExpression { type EvaluatedType = Any - def nullable = false + override def nullable: Boolean = false // We are currently only using these Expressions internally for aggregation. However, if we ever // expose these to users we'll want to create a proper type instead of hijacking ArrayType. - def dataType = ArrayType(elementType) + override def dataType: DataType = ArrayType(elementType) - def eval(input: Row): Any = { + override def eval(input: Row): Any = { new OpenHashSet[Any]() } - override def toString = s"new Set($dataType)" + override def toString: String = s"new Set($dataType)" } /** @@ -46,12 +46,13 @@ case class NewSet(elementType: DataType) extends LeafExpression { case class AddItemToSet(item: Expression, set: Expression) extends Expression { type EvaluatedType = Any - def children = item :: set :: Nil + override def children: Seq[Expression] = item :: set :: Nil - def nullable = set.nullable + override def nullable: Boolean = set.nullable - def dataType = set.dataType - def eval(input: Row): Any = { + override def dataType: DataType = set.dataType + + override def eval(input: Row): Any = { val itemEval = item.eval(input) val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]] @@ -67,7 +68,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } } - override def toString = s"$set += $item" + override def toString: String = s"$set += $item" } /** @@ -77,13 +78,13 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { type EvaluatedType = Any - def nullable = left.nullable || right.nullable + override def nullable: Boolean = left.nullable || right.nullable - def dataType = left.dataType + override def dataType: DataType = left.dataType - def symbol = "++=" + override def symbol: String = "++=" - def eval(input: Row): Any = { + override def eval(input: Row): Any = { val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] if(leftEval != null) { val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]] @@ -109,16 +110,16 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres case class CountSet(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def nullable = child.nullable + override def nullable: Boolean = child.nullable - def dataType = LongType + override def dataType: DataType = LongType - def eval(input: Row): Any = { + override def eval(input: Row): Any = { val childEval = child.eval(input).asInstanceOf[OpenHashSet[Any]] if (childEval != null) { childEval.size.toLong } } - override def toString = s"$child.count()" + override def toString: String = s"$child.count()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index f85ee0a9bb6d8..3cdca4e9dd2d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -33,8 +33,8 @@ trait StringRegexExpression { def escape(v: String): String def matches(regex: Pattern, str: String): Boolean - def nullable: Boolean = left.nullable || right.nullable - def dataType: DataType = BooleanType + override def nullable: Boolean = left.nullable || right.nullable + override def dataType: DataType = BooleanType // try cache the pattern for Literal private lazy val cache: Pattern = right match { @@ -98,11 +98,11 @@ trait CaseConversionExpression { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - def symbol = "LIKE" + override def symbol: String = "LIKE" // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character - override def escape(v: String) = + override def escape(v: String): String = if (!v.isEmpty) { "(?s)" + (' ' +: v.init).zip(v).flatMap { case (prev, '\\') => "" @@ -129,7 +129,7 @@ case class Like(left: Expression, right: Expression) case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - def symbol = "RLIKE" + override def symbol: String = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) } @@ -141,7 +141,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE override def convert(v: String): String = v.toUpperCase() - override def toString() = s"Upper($child)" + override def toString: String = s"Upper($child)" } /** @@ -151,7 +151,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE override def convert(v: String): String = v.toLowerCase() - override def toString() = s"Lower($child)" + override def toString: String = s"Lower($child)" } /** A base trait for functions that compare two strings, returning a boolean. */ @@ -160,7 +160,7 @@ trait StringComparison { type EvaluatedType = Any - def nullable: Boolean = left.nullable || right.nullable + override def nullable: Boolean = left.nullable || right.nullable override def dataType: DataType = BooleanType def compare(l: String, r: String): Boolean @@ -175,9 +175,9 @@ trait StringComparison { } } - def symbol: String = nodeName + override def symbol: String = nodeName - override def toString() = s"$nodeName($left, $right)" + override def toString: String = s"$nodeName($left, $right)" } /** @@ -185,7 +185,7 @@ trait StringComparison { */ case class Contains(left: Expression, right: Expression) extends BinaryExpression with StringComparison { - override def compare(l: String, r: String) = l.contains(r) + override def compare(l: String, r: String): Boolean = l.contains(r) } /** @@ -193,7 +193,7 @@ case class Contains(left: Expression, right: Expression) */ case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with StringComparison { - def compare(l: String, r: String) = l.startsWith(r) + override def compare(l: String, r: String): Boolean = l.startsWith(r) } /** @@ -201,7 +201,7 @@ case class StartsWith(left: Expression, right: Expression) */ case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with StringComparison { - def compare(l: String, r: String) = l.endsWith(r) + override def compare(l: String, r: String): Boolean = l.endsWith(r) } /** @@ -212,17 +212,17 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends type EvaluatedType = Any - override def foldable = str.foldable && pos.foldable && len.foldable + override def foldable: Boolean = str.foldable && pos.foldable && len.foldable - def nullable: Boolean = str.nullable || pos.nullable || len.nullable - def dataType: DataType = { + override def nullable: Boolean = str.nullable || pos.nullable || len.nullable + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") } if (str.dataType == BinaryType) str.dataType else StringType } - override def children = str :: pos :: len :: Nil + override def children: Seq[Expression] = str :: pos :: len :: Nil @inline def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int) @@ -267,7 +267,8 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends } } - override def toString = len match { + override def toString: String = len match { + // TODO: This is broken because max is not an integer value. case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" case _ => s"SUBSTR($str, $pos, $len)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1a75fcf3545bd..c23d3b61887c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter @@ -32,6 +33,9 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] object DefaultOptimizer extends Optimizer { val batches = + // SubQueries are only needed for analysis and can be removed before execution. + Batch("Remove SubQueries", FixedPoint(100), + EliminateSubQueries) :: Batch("Combine Limits", FixedPoint(100), CombineLimits) :: Batch("ConstantFolding", FixedPoint(100), @@ -137,7 +141,7 @@ object ColumnPruning extends Rule[LogicalPlan] { condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) /** Applies a projection only when the child is producing unnecessary attributes */ - def pruneJoinChild(c: LogicalPlan) = prunedChild(c, allReferences) + def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences) Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index b4c445b3badf1..9c8c643f7d17a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -91,16 +91,18 @@ object PhysicalOperation extends PredicateHelper { (None, Nil, other, Map.empty) } - def collectAliases(fields: Seq[Expression]) = fields.collect { + def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { case a @ Alias(child, _) => a.toAttribute.asInstanceOf[Attribute] -> child }.toMap - def substitute(aliases: Map[Attribute, Expression])(expr: Expression) = expr.transform { - case a @ Alias(ref: AttributeReference, name) => - aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) + def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { + expr.transform { + case a @ Alias(ref: AttributeReference, name) => + aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) - case a: AttributeReference => - aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a) + case a: AttributeReference => + aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 17a88e07de15f..02f7c26a8ab6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} +import org.apache.spark.sql.catalyst.expressions.{VirtualColumn, Attribute, AttributeSet, Expression} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} @@ -47,8 +47,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy * Attributes that are referenced by expressions but not provided by this nodes children. * Subclasses should override this method if they produce attributes internally as it is used by * assertions designed to prevent the construction of invalid plans. + * + * Note that virtual columns should be excluded. Currently, we only support the grouping ID + * virtual column. */ - def missingInput: AttributeSet = references -- inputSet + def missingInput: AttributeSet = + (references -- inputSet).filter(_.name != VirtualColumn.groupingIdName) /** * Runs [[transform]] with `rule` on all expressions present in this query operator. @@ -67,7 +71,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { var changed = false - @inline def transformExpressionDown(e: Expression) = { + @inline def transformExpressionDown(e: Expression): Expression = { val newE = e.transformDown(rule) if (newE.fastEquals(e)) { e @@ -81,6 +85,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case e: Expression => transformExpressionDown(e) case Some(e: Expression) => Some(transformExpressionDown(e)) case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionDown(e) case other => other @@ -99,7 +104,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { var changed = false - @inline def transformExpressionUp(e: Expression) = { + @inline def transformExpressionUp(e: Expression): Expression = { val newE = e.transformUp(rule) if (newE.fastEquals(e)) { e @@ -113,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case e: Expression => transformExpressionUp(e) case Some(e: Expression) => Some(transformExpressionUp(e)) case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionUp(e) case other => other @@ -159,5 +165,5 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy */ protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" - override def simpleString = statePrefix + super.simpleString + override def simpleString: String = statePrefix + super.simpleString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 8c4f09b58a4f2..b01a61d7bf8d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, Resolver} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, Resolver} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode @@ -73,12 +73,16 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * can do better should override this function. */ def sameResult(plan: LogicalPlan): Boolean = { - plan.getClass == this.getClass && - plan.children.size == children.size && { - logDebug(s"[${cleanArgs.mkString(", ")}] == [${plan.cleanArgs.mkString(", ")}]") - cleanArgs == plan.cleanArgs + val cleanLeft = EliminateSubQueries(this) + val cleanRight = EliminateSubQueries(plan) + + cleanLeft.getClass == cleanRight.getClass && + cleanLeft.children.size == cleanRight.children.size && { + logDebug( + s"[${cleanRight.cleanArgs.mkString(", ")}] == [${cleanLeft.cleanArgs.mkString(", ")}]") + cleanRight.cleanArgs == cleanLeft.cleanArgs } && - (plan.children, children).zipped.forall(_ sameResult _) + (cleanLeft.children, cleanRight.children).zipped.forall(_ sameResult _) } /** Args that have cleaned such that differences in expression id should not affect equality */ @@ -208,8 +212,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // More than one match. case ambiguousReferences => + val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ") throw new AnalysisException( - s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}") + s"Reference '$name' is ambiguous, could be: $referenceNames.") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 384fe53a68362..4d9e41a2b5d85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { - def output = projectList.map(_.toAttribute) + override def output: Seq[Attribute] = projectList.map(_.toAttribute) override lazy val resolved: Boolean = { val containsAggregatesOrGenerators = projectList.exists ( _.collect { @@ -66,19 +66,19 @@ case class Generate( } } - override def output = + override def output: Seq[Attribute] = if (join) child.output ++ generatorOutput else generatorOutput } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { // TODO: These aren't really the same attributes as nullability etc might change. - override def output = left.output + override def output: Seq[Attribute] = left.output - override lazy val resolved = + override lazy val resolved: Boolean = childrenResolved && !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType } @@ -94,7 +94,7 @@ case class Join( joinType: JoinType, condition: Option[Expression]) extends BinaryNode { - override def output = { + override def output: Seq[Attribute] = { joinType match { case LeftSemi => left.output @@ -109,7 +109,7 @@ case class Join( } } - def selfJoinResolved = left.outputSet.intersect(right.outputSet).isEmpty + private def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguious expression ids. override lazy val resolved: Boolean = { @@ -118,7 +118,7 @@ case class Join( } case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - def output = left.output + override def output: Seq[Attribute] = left.output } case class InsertIntoTable( @@ -128,10 +128,10 @@ case class InsertIntoTable( overwrite: Boolean) extends LogicalPlan { - override def children = child :: Nil - override def output = child.output + override def children: Seq[LogicalPlan] = child :: Nil + override def output: Seq[Attribute] = child.output - override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { + override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { case (childAttr, tableAttr) => DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) } @@ -143,14 +143,14 @@ case class CreateTableAsSelect[T]( child: LogicalPlan, allowExisting: Boolean, desc: Option[T] = None) extends UnaryNode { - override def output = Seq.empty[Attribute] - override lazy val resolved = databaseName != None && childrenResolved + override def output: Seq[Attribute] = Seq.empty[Attribute] + override lazy val resolved: Boolean = databaseName != None && childrenResolved } case class WriteToFile( path: String, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } /** @@ -163,7 +163,7 @@ case class Sort( order: Seq[SortOrder], global: Boolean, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case class Aggregate( @@ -172,7 +172,7 @@ case class Aggregate( child: LogicalPlan) extends UnaryNode { - override def output = aggregateExpressions.map(_.toAttribute) + override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) } /** @@ -199,7 +199,7 @@ trait GroupingAnalytics extends UnaryNode { def groupByExprs: Seq[Expression] def aggregations: Seq[NamedExpression] - override def output = aggregations.map(_.toAttribute) + override def output: Seq[Attribute] = aggregations.map(_.toAttribute) } /** @@ -264,7 +264,7 @@ case class Rollup( gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output override lazy val statistics: Statistics = { val limit = limitExpr.eval(null).asInstanceOf[Int] @@ -274,21 +274,21 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { - override def output = child.output.map(_.withQualifiers(alias :: Nil)) + override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) } case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case class Distinct(child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case object NoRelation extends LeafNode { - override def output = Nil + override def output: Seq[Attribute] = Nil /** * Computes [[Statistics]] for this plan. The default implementation assumes the output @@ -301,5 +301,5 @@ case object NoRelation extends LeafNode { } case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def output = left.output + override def output: Seq[Attribute] = left.output } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index 72b0c5c8e7a26..e737418d9c3bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder} /** * Performs a physical redistribution of the data. Used when the consumer of the query @@ -26,14 +26,11 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} abstract class RedistributeData extends UnaryNode { self: Product => - def output = child.output + override def output: Seq[Attribute] = child.output } case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) - extends RedistributeData { -} + extends RedistributeData case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan) - extends RedistributeData { -} - + extends RedistributeData diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 3c3d7a3119064..288c11f69fe22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Expression, Row, SortOrder} -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{DataType, IntegerType} /** * Specifies how tuples that share common expressions will be distributed when a query is executed @@ -72,7 +72,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { "a single partition.") // TODO: This is not really valid... - def clustering = ordering.map(_.child).toSet + def clustering: Set[Expression] = ordering.map(_.child).toSet } sealed trait Partitioning { @@ -113,7 +113,7 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case SinglePartition => true case _ => false } @@ -124,7 +124,7 @@ case object BroadcastPartitioning extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case SinglePartition => true case _ => false } @@ -139,9 +139,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) extends Expression with Partitioning { - override def children = expressions - override def nullable = false - override def dataType = IntegerType + override def children: Seq[Expression] = expressions + override def nullable: Boolean = false + override def dataType: DataType = IntegerType private[this] lazy val clusteringSet = expressions.toSet @@ -152,7 +152,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case BroadcastPartitioning => true case h: HashPartitioning if h == this => true case _ => false @@ -178,9 +178,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) extends Expression with Partitioning { - override def children = ordering - override def nullable = false - override def dataType = IntegerType + override def children: Seq[SortOrder] = ordering + override def nullable: Boolean = false + override def dataType: DataType = IntegerType private[this] lazy val clusteringSet = ordering.map(_.child).toSet @@ -194,7 +194,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case BroadcastPartitioning => true case r: RangePartitioning if r == this => true case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index f84ffe4e176cc..a2df51e598a2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.types.DataType /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ private class MutableInt(var i: Int) @@ -35,12 +36,12 @@ object CurrentOrigin { override def initialValue: Origin = Origin() } - def get = value.get() - def set(o: Origin) = value.set(o) + def get: Origin = value.get() + def set(o: Origin): Unit = value.set(o) - def reset() = value.set(Origin()) + def reset(): Unit = value.set(Origin()) - def setPosition(line: Int, start: Int) = { + def setPosition(line: Int, start: Int): Unit = { value.set( value.get.copy(line = Some(line), startPosition = Some(start))) } @@ -56,7 +57,7 @@ object CurrentOrigin { abstract class TreeNode[BaseType <: TreeNode[BaseType]] { self: BaseType with Product => - val origin = CurrentOrigin.get + val origin: Origin = CurrentOrigin.get /** Returns a Seq of the children of this node */ def children: Seq[BaseType] @@ -220,6 +221,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { Some(arg) } case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) @@ -276,6 +278,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { Some(arg) } case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) @@ -307,10 +310,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { * @param newArgs the new product arguments. */ def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { + val defaultCtor = + getClass.getConstructors + .find(_.getParameterTypes.size != 0) + .headOption + .getOrElse(sys.error(s"No valid constructor for $nodeName")) + try { CurrentOrigin.withOrigin(origin) { // Skip no-arg constructors that are just there for kryo. - val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head if (otherCopyArgs.isEmpty) { defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type] } else { @@ -320,18 +328,24 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } catch { case e: java.lang.IllegalArgumentException => throw new TreeNodeException( - this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? " - + s"Exception message: ${e.getMessage}.") + this, + s""" + |Failed to copy node. + |Is otherCopyArgs specified correctly for $nodeName. + |Exception message: ${e.getMessage} + |ctor: $defaultCtor? + |args: ${newArgs.mkString(", ")} + """.stripMargin) } } /** Returns the name of this type of TreeNode. Defaults to the class name. */ - def nodeName = getClass.getSimpleName + def nodeName: String = getClass.getSimpleName /** * The arguments that should be included in the arg string. Defaults to the `productIterator`. */ - protected def stringArgs = productIterator + protected def stringArgs: Iterator[Any] = productIterator /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { @@ -343,18 +357,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { }.mkString(", ") /** String representation of this node without any children */ - def simpleString = s"$nodeName $argString".trim + def simpleString: String = s"$nodeName $argString".trim override def toString: String = treeString /** Returns a string representation of the nodes in this tree */ - def treeString = generateTreeString(0, new StringBuilder).toString + def treeString: String = generateTreeString(0, new StringBuilder).toString /** * Returns a string representation of the nodes in this tree, where each operator is numbered. * The numbers can be used with [[trees.TreeNode.apply apply]] to easily access specific subtrees. */ - def numberedTreeString = + def numberedTreeString: String = treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n") /** @@ -406,14 +420,14 @@ trait BinaryNode[BaseType <: TreeNode[BaseType]] { def left: BaseType def right: BaseType - def children = Seq(left, right) + def children: Seq[BaseType] = Seq(left, right) } /** * A [[TreeNode]] with no children. */ trait LeafNode[BaseType <: TreeNode[BaseType]] { - def children = Nil + def children: Seq[BaseType] = Nil } /** @@ -421,6 +435,5 @@ trait LeafNode[BaseType <: TreeNode[BaseType]] { */ trait UnaryNode[BaseType <: TreeNode[BaseType]] { def child: BaseType - def children = child :: Nil + def children: Seq[BaseType] = child :: Nil } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index 79a8e06d4b4d4..ea6aa1850db4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -41,11 +41,11 @@ package object trees extends Logging { * A [[TreeNode]] companion for reference equality for Hash based Collection. */ class TreeNodeRef(val obj: TreeNode[_]) { - override def equals(o: Any) = o match { + override def equals(o: Any): Boolean = o match { case that: TreeNodeRef => that.obj.eq(obj) case _ => false } - override def hashCode = if (obj == null) 0 else obj.hashCode + override def hashCode: Int = if (obj == null) 0 else obj.hashCode } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index feed50f9a2a2d..c86214a2aa944 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -23,7 +23,7 @@ import org.apache.spark.util.Utils package object util { - def fileToString(file: File, encoding: String = "UTF-8") = { + def fileToString(file: File, encoding: String = "UTF-8"): String = { val inStream = new FileInputStream(file) val outStream = new ByteArrayOutputStream try { @@ -45,7 +45,7 @@ package object util { def resourceToString( resource:String, encoding: String = "UTF-8", - classLoader: ClassLoader = Utils.getSparkClassLoader) = { + classLoader: ClassLoader = Utils.getSparkClassLoader): String = { val inStream = classLoader.getResourceAsStream(resource) val outStream = new ByteArrayOutputStream try { @@ -93,7 +93,7 @@ package object util { new String(out.toByteArray) } - def stringOrNull(a: AnyRef) = if (a == null) null else a.toString + def stringOrNull(a: AnyRef): String = if (a == null) null else a.toString def benchmark[A](f: => A): A = { val startTime = System.nanoTime() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index e50e9761431f5..6ee24ee0c1913 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -41,6 +41,9 @@ import org.apache.spark.annotation.DeveloperApi sealed class Metadata private[types] (private[types] val map: Map[String, Any]) extends Serializable { + /** No-arg constructor for kryo. */ + protected def this() = this(null) + /** Tests whether this Metadata contains a binding for a key. */ def contains(key: String): Boolean = map.contains(key) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index d973144de3468..952cf5c75688d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -670,6 +670,10 @@ case class PrecisionInfo(precision: Int, scale: Int) */ @DeveloperApi case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { + + /** No-arg constructor for kryo. */ + protected def this() = this(null) + private[sql] type JvmType = Decimal @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = Decimal.DecimalIsFractional @@ -819,6 +823,10 @@ object ArrayType { */ @DeveloperApi case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { + + /** No-arg constructor for kryo. */ + protected def this() = this(null, false) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append( s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") @@ -857,6 +865,9 @@ case class StructField( nullable: Boolean = true, metadata: Metadata = Metadata.empty) { + /** No-arg constructor for kryo. */ + protected def this() = this(null, null) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") DataType.buildFormattedString(dataType, s"$prefix |", builder) @@ -1003,6 +1014,9 @@ object StructType { @DeveloperApi case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { + /** No-arg constructor for kryo. */ + protected def this() = this(null) + /** Returns all field names in an array. */ def fieldNames: Array[String] = fields.map(_.name) @@ -1121,6 +1135,10 @@ case class MapType( keyType: DataType, valueType: DataType, valueContainsNull: Boolean) extends DataType { + + /** No-arg constructor for kryo. */ + def this() = this(null, null, false) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- key: ${keyType.typeName}\n") builder.append(s"$prefix-- value: ${valueType.typeName} " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index c1dd5aa913ddc..756cd36f05c8c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -32,9 +32,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseInsensitiveCatalog = new SimpleCatalog(false) val caseSensitiveAnalyzer = - new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true) + new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } val caseInsensitiveAnalyzer = - new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) + new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } val checkAnalysis = new CheckAnalysis @@ -199,4 +203,22 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(pl(3).dataType == DecimalType.Unlimited) assert(pl(4).dataType == DoubleType) } + + test("SPARK-6452 regression test") { + // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) + val plan = + Aggregate( + Nil, + Alias(Sum(AttributeReference("a", StringType)(exprId = ExprId(1))), "b")() :: Nil, + LocalRelation( + AttributeReference("a", StringType)(exprId = ExprId(2)))) + + assert(plan.resolved) + + val message = intercept[AnalysisException] { + caseSensitiveAnalyze(plan) + }.getMessage + + assert(message.contains("resolved attribute(s) a#1 missing from a#2")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala new file mode 100644 index 0000000000000..f2f3a84d19380 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalatest.FunSuite + +import org.apache.spark.sql.types.IntegerType + +class AttributeSetSuite extends FunSuite { + + val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) + val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) + val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3)) + val aSet = AttributeSet(aLower :: Nil) + + val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2)) + val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2)) + val bSet = AttributeSet(bUpper :: Nil) + + val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil) + + test("sanity check") { + assert(aUpper != aLower) + assert(bUpper != bLower) + } + + test("checks by id not name") { + assert(aSet.contains(aUpper) === true) + assert(aSet.contains(aLower) === true) + assert(aSet.contains(fakeA) === false) + + assert(aSet.contains(bUpper) === false) + assert(aSet.contains(bLower) === false) + } + + test("++ preserves AttributeSet") { + assert((aSet ++ bSet).contains(aUpper) === true) + assert((aSet ++ bSet).contains(aLower) === true) + } + + test("extracts all references references") { + val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil) + assert(addSet.contains(aUpper)) + assert(addSet.contains(aLower)) + assert(addSet.contains(bUpper)) + assert(addSet.contains(bLower)) + } + + test("dedups attributes") { + assert(AttributeSet(aUpper :: aLower :: Nil).size === 1) + } + + test("subset") { + assert(aSet.subsetOf(aAndBSet) === true) + assert(aAndBSet.subsetOf(aSet) === false) + } + + test("equality") { + assert(aSet != aAndBSet) + assert(aAndBSet != aSet) + assert(aSet != bSet) + assert(bSet != aSet) + + assert(aSet == aSet) + assert(aSet == AttributeSet(aUpper :: Nil)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ec7d15f5bc4e7..3cd7adf8cab5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import scala.language.implicitConversions import org.apache.spark.annotation.Experimental +import org.apache.spark.Logging import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedGetField} @@ -46,7 +47,7 @@ private[sql] object Column { * @groupname Ungrouped Support functions for DataFrames. */ @Experimental -class Column(protected[sql] val expr: Expression) { +class Column(protected[sql] val expr: Expression) extends Logging { def this(name: String) = this(name match { case "*" => UnresolvedStar(None) @@ -109,7 +110,15 @@ class Column(protected[sql] val expr: Expression) { * * @group expr_ops */ - def === (other: Any): Column = EqualTo(expr, lit(other).expr) + def === (other: Any): Column = { + val right = lit(other).expr + if (this.expr == right) { + logWarning( + s"Constructing trivially true equals predicate, '${this.expr} = $right'. " + + "Perhaps you need to use aliases.") + } + EqualTo(expr, right) + } /** * Equality test. @@ -594,6 +603,19 @@ class Column(protected[sql] val expr: Expression) { */ def as(alias: Symbol): Column = Alias(expr, alias.name)() + /** + * Gives the column an alias with metadata. + * {{{ + * val metadata: Metadata = ... + * df.select($"colA".as("colB", metadata)) + * }}} + * + * @group expr_ops + */ + def as(alias: String, metadata: Metadata): Column = { + Alias(expr, alias)(explicitMetadata = Some(metadata)) + } + /** * Casts the column to a different data type. * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5aece166aad22..4c80359cf07af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -33,7 +33,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.sql.catalyst.{expressions, ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser} import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} import org.apache.spark.sql.jdbc.JDBCWriteDetails import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.types.{NumericType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} import org.apache.spark.util.Utils @@ -751,6 +751,67 @@ class DataFrame private[sql]( select(colNames :_*) } + /** + * Computes statistics for numeric columns, including count, mean, stddev, min, and max. + * If no columns are given, this function computes statistics for all numerical columns. + * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. If you want to + * programmatically compute summary statistics, use the `agg` function instead. + * + * {{{ + * df.describe("age", "height").show() + * + * // output: + * // summary age height + * // count 10.0 10.0 + * // mean 53.3 178.05 + * // stddev 11.6 15.7 + * // min 18.0 163.0 + * // max 92.0 192.0 + * }}} + * + * @group action + */ + @scala.annotation.varargs + def describe(cols: String*): DataFrame = { + + // TODO: Add stddev as an expression, and remove it from here. + def stddevExpr(expr: Expression): Expression = + Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr)))) + + // The list of summary statistics to compute, in the form of expressions. + val statistics = List[(String, Expression => Expression)]( + "count" -> Count, + "mean" -> Average, + "stddev" -> stddevExpr, + "min" -> Min, + "max" -> Max) + + val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList + + val ret: Seq[Row] = if (outputCols.nonEmpty) { + val aggExprs = statistics.flatMap { case (_, colToAgg) => + outputCols.map(c => Column(colToAgg(Column(c).expr)).as(c)) + } + + val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + + // Pivot the data so each summary is one row + row.grouped(outputCols.size).toSeq.zip(statistics).map { + case (aggregation, (statistic, _)) => Row(statistic :: aggregation.toList: _*) + } + } else { + // If there are no output columns, just output a single column that contains the stats. + statistics.map { case (name, _) => Row(name) } + } + + // The first column is string type, and the rest are double type. + val schema = StructType( + StructField("summary", StringType) :: outputCols.map(StructField(_, DoubleType))).toAttributes + LocalRelation(schema, ret) + } + /** * Returns the first `n` rows. * @group action diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dc9912b52dcab..e59cf9b9e037b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1210,38 +1210,56 @@ class SQLContext(@transient val sparkContext: SparkContext) * Returns a Catalyst Schema for the given java bean class. */ protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { + val (dataType, _) = inferDataType(beanClass) + dataType.asInstanceOf[StructType].fields.map { f => + AttributeReference(f.name, f.dataType, f.nullable)() + } + } + + /** + * Infers the corresponding SQL data type of a Java class. + * @param clazz Java class + * @return (SQL data type, nullable) + */ + private def inferDataType(clazz: Class[_]): (DataType, Boolean) = { // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. - val beanInfo = Introspector.getBeanInfo(beanClass) - - // Note: The ordering of elements may differ from when the schema is inferred in Scala. - // This is because beanInfo.getPropertyDescriptors gives no guarantees about - // element ordering. - val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") - fields.map { property => - val (dataType, nullable) = property.getPropertyType match { - case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => - (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) - case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) - case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) - case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) - case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) - case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) - case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) - case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) - case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) - - case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) - case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) - case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) - case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) - case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) - case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) - case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) - case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) - case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) - case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) - } - AttributeReference(property.getName, dataType, nullable)() + clazz match { + case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => + (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) + + case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) + case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) + case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) + case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) + case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) + case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) + case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) + case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) + + case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) + case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) + case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) + case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) + case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) + case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) + case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) + + case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) + case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) + case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) + + case c: Class[_] if c.isArray => + val (dataType, nullable) = inferDataType(c.getComponentType) + (ArrayType(dataType, nullable), true) + + case _ => + val beanInfo = Introspector.getBeanInfo(clazz) + val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + val fields = properties.map { property => + val (dataType, nullable) = inferDataType(property.getPropertyType) + new StructField(property.getName, dataType, nullable) + } + (new StructType(fields), true) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index c4534fd5f67e4..967bd76b302d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHa private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { - val kryo = new Kryo() + val kryo = super.newKryo() kryo.setRegistrationRequired(false) kryo.register(classOf[MutablePair[_, _]]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) @@ -57,8 +57,6 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[Decimal]) kryo.setReferences(false) - kryo.setClassLoader(Utils.getSparkClassLoader) - new AllScalaRegistrar().apply(kryo) kryo } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 20c9bc3e75542..1f5251a20376f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.util.MutablePair +import org.apache.spark.util.{CompletionIterator, MutablePair} import org.apache.spark.util.collection.ExternalSorter /** @@ -194,7 +194,9 @@ case class ExternalSort( val ordering = newOrdering(sortOrder, child.output) val sorter = new ExternalSorter[Row, Null, Row](ordering = Some(ordering)) sorter.insertAll(iterator.map(r => (r, null))) - sorter.iterator.map(_._1) + val baseIterator = sorter.iterator.map(_._1) + // TODO(marmbrus): The complex type signature below thwarts inference for no reason. + CompletionIterator[Row, Iterator[Row]](baseIterator, sorter.stop()) }, preservesPartitioning = true) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 76f8593180e85..463e1dcc268bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} +import java.util.Properties import org.apache.commons.lang.StringEscapeUtils.escapeSql import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} @@ -90,9 +91,9 @@ private[sql] object JDBCRDD extends Logging { * @throws SQLException if the table specification is garbage. * @throws SQLException if the table contains an unsupported type. */ - def resolveTable(url: String, table: String): StructType = { + def resolveTable(url: String, table: String, properties: Properties): StructType = { val quirks = DriverQuirks.get(url) - val conn: Connection = DriverManager.getConnection(url) + val conn: Connection = DriverManager.getConnection(url, properties) try { val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery() try { @@ -147,7 +148,7 @@ private[sql] object JDBCRDD extends Logging { * * @return A function that loads the driver and connects to the url. */ - def getConnector(driver: String, url: String): () => Connection = { + def getConnector(driver: String, url: String, properties: Properties): () => Connection = { () => { try { if (driver != null) Class.forName(driver) @@ -156,7 +157,7 @@ private[sql] object JDBCRDD extends Logging { logWarning(s"Couldn't find class $driver", e); } } - DriverManager.getConnection(url) + DriverManager.getConnection(url, properties) } } /** @@ -179,6 +180,7 @@ private[sql] object JDBCRDD extends Logging { schema: StructType, driver: String, url: String, + properties: Properties, fqTable: String, requiredColumns: Array[String], filters: Array[Filter], @@ -189,7 +191,7 @@ private[sql] object JDBCRDD extends Logging { return new JDBCRDD( sc, - getConnector(driver, url), + getConnector(driver, url, properties), prunedSchema, fqTable, requiredColumns, @@ -361,7 +363,7 @@ private[sql] class JDBCRDD( var ans = 0L var j = 0 while (j < bytes.size) { - ans = 256*ans + (255 & bytes(j)) + ans = 256 * ans + (255 & bytes(j)) j = j + 1; } mutableRow.setLong(i, ans) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index df687e6da9bea..4fa84dc076f7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql.jdbc -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.types.StructType +import java.sql.DriverManager +import java.util.Properties import scala.collection.mutable.ArrayBuffer -import java.sql.DriverManager import org.apache.spark.Partition +import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.StructType /** * Data corresponding to one partition of a JDBCRDD. @@ -115,18 +116,21 @@ private[sql] class DefaultSource extends RelationProvider { numPartitions.toInt) } val parts = JDBCRelation.columnPartition(partitionInfo) - JDBCRelation(url, table, parts)(sqlContext) + val properties = new Properties() // Additional properties that we will pass to getConnection + parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) + JDBCRelation(url, table, parts, properties)(sqlContext) } } private[sql] case class JDBCRelation( url: String, table: String, - parts: Array[Partition])(@transient val sqlContext: SQLContext) + parts: Array[Partition], + properties: Properties = new Properties())(@transient val sqlContext: SQLContext) extends BaseRelation with PrunedFilteredScan { - override val schema: StructType = JDBCRDD.resolveTable(url, table) + override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName @@ -135,6 +139,7 @@ private[sql] case class JDBCRelation( schema, driver, url, + properties, table, requiredColumns, filters, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index f898e4b37a56b..43ca359b51735 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -127,6 +127,12 @@ private[sql] object CatalystConverter { parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType]) } } + case DateType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addInt(value: Int): Unit = + parent.updateDate(fieldIndex, value.asInstanceOf[DateType.JvmType]) + } + } case d: DecimalType => { new CatalystPrimitiveConverter(parent, fieldIndex) { override def addBinary(value: Binary): Unit = @@ -192,6 +198,9 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = updateField(fieldIndex, value) + protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = + updateField(fieldIndex, value) + protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = updateField(fieldIndex, value) @@ -388,6 +397,9 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = current.setInt(fieldIndex, value) + override protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = + current.update(fieldIndex, value) + override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = current.setLong(fieldIndex, value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 19bfba34b8f4a..5a1b15490d273 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -212,6 +212,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case DoubleType => writer.addDouble(value.asInstanceOf[Double]) case FloatType => writer.addFloat(value.asInstanceOf[Float]) case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case DateType => writer.addInteger(value.asInstanceOf[Int]) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") @@ -358,6 +359,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { case DoubleType => writer.addDouble(record.getDouble(index)) case FloatType => writer.addFloat(record.getFloat(index)) case BooleanType => writer.addBoolean(record.getBoolean(index)) + case DateType => writer.addInteger(record.getInt(index)) case TimestampType => writeTimestamp(record(index).asInstanceOf[java.sql.Timestamp]) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 5209581fa8357..da668f068613b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -64,6 +64,8 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetPrimitiveTypeName.BOOLEAN => BooleanType case ParquetPrimitiveTypeName.DOUBLE => DoubleType case ParquetPrimitiveTypeName.FLOAT => FloatType + case ParquetPrimitiveTypeName.INT32 + if originalType == ParquetOriginalType.DATE => DateType case ParquetPrimitiveTypeName.INT32 => IntegerType case ParquetPrimitiveTypeName.INT64 => LongType case ParquetPrimitiveTypeName.INT96 if int96AsTimestamp => TimestampType @@ -222,6 +224,8 @@ private[parquet] object ParquetTypesConverter extends Logging { // There is no type for Byte or Short so we promote them to INT32. case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) + case DateType => Some(ParquetTypeInfo( + ParquetPrimitiveTypeName.INT32, Some(ParquetOriginalType.DATE))) case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64)) case TimestampType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT96)) case DecimalType.Fixed(precision, scale) if precision <= 18 => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 410600b0529d3..0d68810ec6043 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -435,11 +435,18 @@ private[sql] case class ParquetRelation2( // Push down filters when possible. Notice that not all filters can be converted to Parquet // filter predicate. Here we try to convert each individual predicate and only collect those // convertible ones. - predicates - .flatMap(ParquetFilters.createFilter) - .reduceOption(FilterApi.and) - .filter(_ => sqlContext.conf.parquetFilterPushDown) - .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) + if (sqlContext.conf.parquetFilterPushDown) { + predicates + // Don't push down predicates which reference partition columns + .filter { pred => + val partitionColNames = partitionColumns.map(_.name).toSet + val referencedColNames = pred.references.map(_.name).toSet + referencedColNames.intersect(partitionColNames).isEmpty + } + .flatMap(ParquetFilters.createFilter) + .reduceOption(FilterApi.and) + .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) + } if (isPartitioned) { logInfo { @@ -758,12 +765,13 @@ private[sql] object ParquetRelation2 extends Logging { |${parquetSchema.prettyJson} """.stripMargin - assert(metastoreSchema.size == parquetSchema.size, schemaConflictMessage) + assert(metastoreSchema.size <= parquetSchema.size, schemaConflictMessage) val ordinalMap = metastoreSchema.zipWithIndex.map { case (field, index) => field.name.toLowerCase -> index }.toMap - val reorderedParquetSchema = parquetSchema.sortBy(f => ordinalMap(f.name.toLowerCase)) + val reorderedParquetSchema = parquetSchema.sortBy(f => + ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) StructType(metastoreSchema.zip(reorderedParquetSchema).map { // Uses Parquet field names but retains Metastore data types. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index d2e807d3a69b6..eb46b46ca5bf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -21,7 +21,7 @@ import scala.language.existentials import scala.language.implicitConversions import org.apache.spark.Logging -import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} +import org.apache.spark.sql.{AnalysisException, SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation @@ -204,19 +204,25 @@ private[sql] object ResolvedDataSource { provider: String, options: Map[String, String]): ResolvedDataSource = { val clazz: Class[_] = lookupDataSource(provider) + def className = clazz.getCanonicalName val relation = userSpecifiedSchema match { case Some(schema: StructType) => clazz.newInstance() match { case dataSource: SchemaRelationProvider => dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) case dataSource: org.apache.spark.sql.sources.RelationProvider => - sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.") + throw new AnalysisException(s"$className does not allow user-specified schemas.") + case _ => + throw new AnalysisException(s"$className is not a RelationProvider.") } case None => clazz.newInstance() match { case dataSource: RelationProvider => dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => - sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.") + throw new AnalysisException( + s"A schema needs to be specified when using $className.") + case _ => + throw new AnalysisException(s"$className is not a RelationProvider.") } } new ResolvedDataSource(clazz, relation) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 2d586f784ac5a..1ff2d5a190521 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -17,29 +17,39 @@ package test.org.apache.spark.sql; +import java.io.Serializable; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; +import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.test.TestSQLContext$; -import static org.apache.spark.sql.functions.*; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.functions.*; public class JavaDataFrameSuite { + private transient JavaSparkContext jsc; private transient SQLContext context; @Before public void setUp() { // Trigger static initializer of TestData TestData$.MODULE$.testData(); + jsc = new JavaSparkContext(TestSQLContext.sparkContext()); context = TestSQLContext$.MODULE$; } @After public void tearDown() { + jsc = null; context = null; } @@ -90,4 +100,33 @@ public void testShow() { df.show(); df.show(1000); } + + public static class Bean implements Serializable { + private double a = 0.0; + private Integer[] b = new Integer[]{0, 1}; + + public double getA() { + return a; + } + + public Integer[] getB() { + return b; + } + } + + @Test + public void testCreateDataFrameFromJavaBeans() { + Bean bean = new Bean(); + JavaRDD rdd = jsc.parallelize(Arrays.asList(bean)); + DataFrame df = context.createDataFrame(rdd, Bean.class); + StructType schema = df.schema(); + Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()), + schema.apply("a")); + Assert.assertEquals( + new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()), + schema.apply("b")); + Row first = df.select("a", "b").first(); + Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); + Assert.assertArrayEquals(bean.getB(), first.getAs(1)); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index a53ae97d6243a..bc8fae100db6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.NamedExpression -import org.apache.spark.sql.catalyst.plans.logical.{Project, NoRelation} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types._ class ColumnExpressionSuite extends QueryTest { @@ -322,4 +320,15 @@ class ColumnExpressionSuite extends QueryTest { assert('key.desc == 'key.desc) assert('key.desc != 'key.asc) } + + test("alias with metadata") { + val metadata = new MetadataBuilder() + .putString("originName", "value") + .build() + val schema = testData + .select($"*", col("value").as("abc", metadata)) + .schema + assert(schema("value").metadata === Metadata.empty) + assert(schema("abc").metadata === metadata) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ff441ef26f9c0..fbc4065a9666c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -108,6 +108,13 @@ class DataFrameSuite extends QueryTest { ) } + test("self join with aliases") { + val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + } + test("explode") { val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters") val df2 = @@ -436,6 +443,50 @@ class DataFrameSuite extends QueryTest { assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol")) } + test("describe") { + val describeTestData = Seq( + ("Bob", 16, 176), + ("Alice", 32, 164), + ("David", 60, 192), + ("Amy", 24, 180)).toDF("name", "age", "height") + + val describeResult = Seq( + Row("count", 4, 4), + Row("mean", 33.0, 178.0), + Row("stddev", 16.583123951777, 10.0), + Row("min", 16, 164), + Row("max", 60, 192)) + + val emptyDescribeResult = Seq( + Row("count", 0, 0), + Row("mean", null, null), + Row("stddev", null, null), + Row("min", null, null), + Row("max", null, null)) + + def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) + + val describeTwoCols = describeTestData.describe("age", "height") + assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height")) + checkAnswer(describeTwoCols, describeResult) + + val describeAllCols = describeTestData.describe() + assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height")) + checkAnswer(describeAllCols, describeResult) + + val describeOneCol = describeTestData.describe("age") + assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) + checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} ) + + val describeNoCol = describeTestData.select("name").describe() + assert(getSchemaAsSeq(describeNoCol) === Seq("summary")) + checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _) => Row(s)} ) + + val emptyDescription = describeTestData.limit(0).describe() + assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height")) + checkAnswer(emptyDescription, emptyDescribeResult) + } + test("apply on query results (SPARK-5462)") { val df = testData.sqlContext.sql("select key from testData") checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index dd0948ad824be..e4dee87849fd4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -34,7 +34,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") - val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.analyzed + val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan val planned = planner.HashJoin(join) assert(planned.size === 1) } @@ -109,7 +109,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") - val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.analyzed + val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan val planned = planner.HashJoin(join) assert(planned.size === 1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index f5b945f468dad..36465cc2fa11a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql +import org.apache.spark.sql.execution.SparkSqlSerializer import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ class RowSuite extends FunSuite { @@ -50,4 +53,13 @@ class RowSuite extends FunSuite { row(0) = null assert(row.isNullAt(0)) } + + test("serialize w/ kryo") { + val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() + val serializer = new SparkSqlSerializer(TestSQLContext.sparkContext.getConf) + val instance = serializer.newInstance() + val ser = instance.serialize(row) + val de = instance.deserialize(ser).asInstanceOf[Row] + assert(de === row) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index be105c6e83594..d615542ab50a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -50,4 +50,10 @@ class UDFSuite extends QueryTest { .select($"ret.f1").head().getString(0) assert(result === "test") } + + test("udf that is transformed") { + udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + // 1 + 1 is constant folded causing a transformation. + assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5eb6ab2e92e8b..592ed4b23b7d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -19,22 +19,31 @@ package org.apache.spark.sql.jdbc import java.math.BigDecimal import java.sql.DriverManager -import java.util.{Calendar, GregorianCalendar} +import java.util.{Calendar, GregorianCalendar, Properties} import org.apache.spark.sql.test._ +import org.h2.jdbc.JdbcSQLException import org.scalatest.{FunSuite, BeforeAndAfter} import TestSQLContext._ import TestSQLContext.implicits._ class JDBCSuite extends FunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb0" + val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) before { Class.forName("org.h2.Driver") - conn = DriverManager.getConnection(url) + // Extra properties that will be specified for our database. We need these to test + // usage of parameters from OPTIONS clause in queries. + val properties = new Properties() + properties.setProperty("user", "testUser") + properties.setProperty("password", "testPass") + properties.setProperty("rowId", "false") + + conn = DriverManager.getConnection(url, properties) conn.prepareStatement("create schema test").executeUpdate() conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate() @@ -46,15 +55,15 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE foobar |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.PEOPLE') + |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) sql( s""" |CREATE TEMPORARY TABLE parts |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', - |partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3') + |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass', + | partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3') """.stripMargin.replaceAll("\n", " ")) conn.prepareStatement("create table test.inttypes (a INT, b BOOLEAN, c TINYINT, " @@ -68,12 +77,12 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE inttypes |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.INTTYPES') + |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) conn.prepareStatement("create table test.strtypes (a BINARY(20), b VARCHAR(20), " + "c VARCHAR_IGNORECASE(20), d CHAR(20), e BLOB, f CLOB)").executeUpdate() - var stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)") + val stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)") stmt.setBytes(1, testBytes) stmt.setString(2, "Sensitive") stmt.setString(3, "Insensitive") @@ -85,7 +94,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE strtypes |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.STRTYPES') + |OPTIONS (url '$url', dbtable 'TEST.STRTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) conn.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP)" @@ -97,7 +106,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE timetypes |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES') + |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) @@ -112,7 +121,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { s""" |CREATE TEMPORARY TABLE flttypes |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES') + |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. @@ -174,16 +183,17 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("Basic API") { - assert(TestSQLContext.jdbc(url, "TEST.PEOPLE").collect.size == 3) + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect.size == 3) } test("Partitioning via JDBCPartitioningInfo API") { - assert(TestSQLContext.jdbc(url, "TEST.PEOPLE", "THEID", 0, 4, 3).collect.size == 3) + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3) + .collect.size == 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(TestSQLContext.jdbc(url, "TEST.PEOPLE", parts).collect.size == 3) + assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect.size == 3) } test("H2 integral types") { @@ -216,7 +226,6 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { assert(rows(0).getString(5).equals("I am a clob!")) } - test("H2 time types") { val rows = sql("SELECT * FROM timetypes").collect() val cal = new GregorianCalendar(java.util.Locale.ROOT) @@ -246,17 +255,31 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { .equals(new BigDecimal("123456789012345.54321543215432100000"))) } - test("SQL query as table name") { sql( s""" |CREATE TEMPORARY TABLE hack |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)') + |OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)', + | user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) val rows = sql("SELECT * FROM hack").collect() assert(rows(0).getDouble(0) == 1.00000011920928955) // Yes, I meant ==. // For some reason, H2 computes this square incorrectly... assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12) } + + test("Pass extra properties via OPTIONS") { + // We set rowId to false during setup, which means that _ROWID_ column should be absent from + // all tables. If rowId is true (default), the query below doesn't throw an exception. + intercept[JdbcSQLException] { + sql( + s""" + |CREATE TEMPORARY TABLE abc + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable '(SELECT _ROWID_ FROM test.people)', + | user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 4d32e84fc1115..6a2c2a7c4080a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -321,6 +321,23 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA override protected def afterAll(): Unit = { sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) } + + test("SPARK-6554: don't push down predicates which reference partition columns") { + import sqlContext.implicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").saveAsParquetFile(path) + + // If the "part = 1" filter gets pushed down, this query will throw an exception since + // "part" is not a valid column in the actual Parquet file + checkAnswer( + sqlContext.parquetFile(path).filter("part = 1"), + (1 to 3).map(i => Row(i, i.toString, 1))) + } + } + } } class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 5438095addeaf..203bc79f153dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -135,6 +135,21 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } } + test("date type") { + def makeDateRDD(): DataFrame = + sparkContext + .parallelize(0 to 1000) + .map(i => Tuple1(DateUtils.toJavaDate(i))) + .toDF() + .select($"_1") + + withTempPath { dir => + val data = makeDateRDD() + data.saveAsParquetFile(dir.getCanonicalPath) + checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) + } + } + test("map") { val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) checkParquetFile(data) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index ad880e2bc3679..8462f9bb2d620 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -57,7 +57,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { |} """.stripMargin) - testSchema[(Byte, Short, Int, Long)]( + testSchema[(Byte, Short, Int, Long, java.sql.Date)]( "logical integral types", """ |message root { @@ -65,6 +65,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { | required int32 _2 (INT_16); | required int32 _3 (INT_32); | required int64 _4 (INT_64); + | optional int32 _5 (DATE); |} """.stripMargin) @@ -211,8 +212,11 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructField("UPPERCase", IntegerType, nullable = true)))) } - // Conflicting field count - assert(intercept[Throwable] { + // MetaStore schema is subset of parquet schema + assertResult( + StructType(Seq( + StructField("UPPERCase", DoubleType, nullable = false)))) { + ParquetRelation2.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false))), @@ -220,6 +224,17 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructType(Seq( StructField("lowerCase", BinaryType), StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // Conflicting field count + assert(intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false), + StructField("lowerCase", BinaryType))), + + StructType(Seq( + StructField("UPPERCase", IntegerType, nullable = true)))) }.getMessage.contains("detected conflicting schemas")) // Conflicting field names diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 4c5eb48661f7d..d1a99555e90c6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -459,7 +459,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation, parquetRelation, attributedRewrites) + (relation -> relation.output, parquetRelation, attributedRewrites) // Write path case InsertIntoHiveTable(relation: MetastoreRelation, _, _, _) @@ -470,7 +470,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation, parquetRelation, attributedRewrites) + (relation -> relation.output, parquetRelation, attributedRewrites) // Read path case p @ PhysicalOperation(_, _, relation: MetastoreRelation) @@ -479,33 +479,35 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation, parquetRelation, attributedRewrites) + (relation -> relation.output, parquetRelation, attributedRewrites) } + // Quick fix for SPARK-6450: Notice that we're using both the MetastoreRelation instances and + // their output attributes as the key of the map. This is because MetastoreRelation.equals + // doesn't take output attributes into account, thus multiple MetastoreRelation instances + // pointing to the same table get collapsed into a single entry in the map. A proper fix for + // this should be overriding equals & hashCode in MetastoreRelation. val relationMap = toBeReplaced.map(r => (r._1, r._2)).toMap val attributedRewrites = AttributeMap(toBeReplaced.map(_._3).fold(Nil)(_ ++: _)) // Replaces all `MetastoreRelation`s with corresponding `ParquetRelation2`s, and fixes // attribute IDs referenced in other nodes. plan.transformUp { - case r: MetastoreRelation if relationMap.contains(r) => { - val parquetRelation = relationMap(r) - val withAlias = - r.alias.map(a => Subquery(a, parquetRelation)).getOrElse( - Subquery(r.tableName, parquetRelation)) + case r: MetastoreRelation if relationMap.contains(r -> r.output) => + val parquetRelation = relationMap(r -> r.output) + val alias = r.alias.getOrElse(r.tableName) + Subquery(alias, parquetRelation) - withAlias - } case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite) - if relationMap.contains(r) => { - val parquetRelation = relationMap(r) + if relationMap.contains(r -> r.output) => + val parquetRelation = relationMap(r -> r.output) InsertIntoTable(parquetRelation, partition, child, overwrite) - } + case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite) - if relationMap.contains(r) => { - val parquetRelation = relationMap(r) + if relationMap.contains(r -> r.output) => + val parquetRelation = relationMap(r -> r.output) InsertIntoTable(parquetRelation, partition, child, overwrite) - } + case other => other.transformExpressions { case a: Attribute if a.resolved => attributedRewrites.getOrElse(a, a) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 51775eb4cd6a0..c45c4ad70fae9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -55,37 +55,8 @@ private[hive] case object NativePlaceholder extends Command /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ private[hive] object HiveQl { protected val nativeCommands = Seq( - "TOK_DESCFUNCTION", - "TOK_DESCDATABASE", - "TOK_SHOW_CREATETABLE", - "TOK_SHOWCOLUMNS", - "TOK_SHOW_TABLESTATUS", - "TOK_SHOWDATABASES", - "TOK_SHOWFUNCTIONS", - "TOK_SHOWINDEXES", - "TOK_SHOWINDEXES", - "TOK_SHOWPARTITIONS", - "TOK_SHOW_TBLPROPERTIES", - - "TOK_LOCKTABLE", - "TOK_SHOWLOCKS", - "TOK_UNLOCKTABLE", - - "TOK_SHOW_ROLES", - "TOK_CREATEROLE", - "TOK_DROPROLE", - "TOK_GRANT", - "TOK_GRANT_ROLE", - "TOK_REVOKE", - "TOK_SHOW_GRANT", - "TOK_SHOW_ROLE_GRANT", - "TOK_SHOW_SET_ROLE", - - "TOK_CREATEFUNCTION", - "TOK_DROPFUNCTION", - - "TOK_ALTERDATABASE_PROPERTIES", "TOK_ALTERDATABASE_OWNER", + "TOK_ALTERDATABASE_PROPERTIES", "TOK_ALTERINDEX_PROPERTIES", "TOK_ALTERINDEX_REBUILD", "TOK_ALTERTABLE_ADDCOLS", @@ -102,28 +73,61 @@ private[hive] object HiveQl { "TOK_ALTERTABLE_SKEWED", "TOK_ALTERTABLE_TOUCH", "TOK_ALTERTABLE_UNARCHIVE", - "TOK_CREATEDATABASE", - "TOK_CREATEFUNCTION", - "TOK_CREATEINDEX", - "TOK_DROPDATABASE", - "TOK_DROPINDEX", - "TOK_DROPTABLE_PROPERTIES", - "TOK_MSCK", - "TOK_ALTERVIEW_ADDPARTS", "TOK_ALTERVIEW_AS", "TOK_ALTERVIEW_DROPPARTS", "TOK_ALTERVIEW_PROPERTIES", "TOK_ALTERVIEW_RENAME", + + "TOK_CREATEDATABASE", + "TOK_CREATEFUNCTION", + "TOK_CREATEINDEX", + "TOK_CREATEROLE", "TOK_CREATEVIEW", - "TOK_DROPVIEW_PROPERTIES", + + "TOK_DESCDATABASE", + "TOK_DESCFUNCTION", + + "TOK_DROPDATABASE", + "TOK_DROPFUNCTION", + "TOK_DROPINDEX", + "TOK_DROPROLE", + "TOK_DROPTABLE_PROPERTIES", "TOK_DROPVIEW", - + "TOK_DROPVIEW_PROPERTIES", + "TOK_EXPORT", + + "TOK_GRANT", + "TOK_GRANT_ROLE", + "TOK_IMPORT", + "TOK_LOAD", - - "TOK_SWITCHDATABASE" + + "TOK_LOCKTABLE", + + "TOK_MSCK", + + "TOK_REVOKE", + + "TOK_SHOW_CREATETABLE", + "TOK_SHOW_GRANT", + "TOK_SHOW_ROLE_GRANT", + "TOK_SHOW_ROLES", + "TOK_SHOW_SET_ROLE", + "TOK_SHOW_TABLESTATUS", + "TOK_SHOW_TBLPROPERTIES", + "TOK_SHOWCOLUMNS", + "TOK_SHOWDATABASES", + "TOK_SHOWFUNCTIONS", + "TOK_SHOWINDEXES", + "TOK_SHOWLOCKS", + "TOK_SHOWPARTITIONS", + + "TOK_SWITCHDATABASE", + + "TOK_UNLOCKTABLE" ) // Commands that we do not need to explain. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index af309c0c6ce2c..3563472c7ae81 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} @@ -116,7 +116,7 @@ class HadoopTableReader( val hconf = broadcastedHiveConf.value.value val deserializer = deserializerClass.newInstance() deserializer.initialize(hconf, tableDesc.getProperties) - HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow) + HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer) } deserializedHadoopRDD @@ -189,9 +189,13 @@ class HadoopTableReader( val hconf = broadcastedHiveConf.value.value val deserializer = localDeserializer.newInstance() deserializer.initialize(hconf, partProps) + // get the table deserializer + val tableSerDe = tableDesc.getDeserializerClass.newInstance() + tableSerDe.initialize(hconf, tableDesc.getProperties) // fill the non partition key attributes - HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, mutableRow) + HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, + mutableRow, tableSerDe) } }.toSeq @@ -261,25 +265,36 @@ private[hive] object HadoopTableReader extends HiveInspectors { * Transform all given raw `Writable`s into `Row`s. * * @param iterator Iterator of all `Writable`s to be transformed - * @param deserializer The `Deserializer` associated with the input `Writable` + * @param rawDeser The `Deserializer` associated with the input `Writable` * @param nonPartitionKeyAttrs Attributes that should be filled together with their corresponding * positions in the output schema * @param mutableRow A reusable `MutableRow` that should be filled + * @param tableDeser Table Deserializer * @return An `Iterator[Row]` transformed from `iterator` */ def fillObject( iterator: Iterator[Writable], - deserializer: Deserializer, + rawDeser: Deserializer, nonPartitionKeyAttrs: Seq[(Attribute, Int)], - mutableRow: MutableRow): Iterator[Row] = { + mutableRow: MutableRow, + tableDeser: Deserializer): Iterator[Row] = { + + val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) { + rawDeser.getObjectInspector.asInstanceOf[StructObjectInspector] + } else { + HiveShim.getConvertedOI( + rawDeser.getObjectInspector, + tableDeser.getObjectInspector).asInstanceOf[StructObjectInspector] + } - val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector] val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => soi.getStructFieldRef(attr.name) -> ordinal }.unzip - // Builds specific unwrappers ahead of time according to object inspector types to avoid pattern - // matching and branching costs per row. + /** + * Builds specific unwrappers ahead of time according to object inspector + * types to avoid pattern matching and branching costs per row. + */ val unwrappers: Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map { _.getFieldObjectInspector match { case oi: BooleanObjectInspector => @@ -316,9 +331,11 @@ private[hive] object HadoopTableReader extends HiveInspectors { } } + val converter = ObjectInspectorConverters.getConverter(rawDeser.getObjectInspector, soi) + // Map each tuple to a row object iterator.map { value => - val raw = deserializer.deserialize(value) + val raw = converter.convert(rawDeser.deserialize(value)) var i = 0 while (i < fieldRefs.length) { val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index bfe43373d9534..47305571e579e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -375,9 +375,8 @@ private[hive] case class HiveUdafFunction( private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) - // Cast required to avoid type inference selecting a deprecated Hive API. private val buffer = - function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer] + function.getNewAggregationBuffer override def eval(input: Row): Any = unwrap(function.evaluate(buffer), returnInspector) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index dc61e9d2e3522..a3497eadd67f6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -23,6 +23,7 @@ import java.util.{Set => JavaSet} import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat} import org.apache.hadoop.hive.ql.metadata.Table +import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.RegexSerDe import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe @@ -153,8 +154,13 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val describedTable = "DESCRIBE (\\w+)".r + val vs = new VariableSubstitution() + + // we should substitute variables in hql to pass the text to parseSql() as a parameter. + // Hive parser need substituted text. HiveContext.sql() does this but return a DataFrame, + // while we need a logicalPlan so we cannot reuse that. protected[hive] class HiveQLQueryExecution(hql: String) - extends this.QueryExecution(HiveQl.parseSql(hql)) { + extends this.QueryExecution(HiveQl.parseSql(vs.substitute(hiveconf, hql))) { def hiveExec(): Seq[String] = runSqlHive(hql) override def toString: String = hql + "\n" + super.toString } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 44d24273e722a..221a0c263d36c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -92,12 +92,12 @@ class CachedTableSuite extends QueryTest { } test("Drop cached table") { - sql("CREATE TABLE test(a INT)") - cacheTable("test") - sql("SELECT * FROM test").collect() - sql("DROP TABLE test") + sql("CREATE TABLE cachedTableTest(a INT)") + cacheTable("cachedTableTest") + sql("SELECT * FROM cachedTableTest").collect() + sql("DROP TABLE cachedTableTest") intercept[AnalysisException] { - sql("SELECT * FROM test").collect() + sql("SELECT * FROM cachedTableTest").collect() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index f04437c595bf6..968557c9c4686 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -19,12 +19,29 @@ package org.apache.spark.sql.hive import java.io.{OutputStream, PrintStream} +import scala.util.Try + +import org.scalatest.BeforeAndAfter + import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.{AnalysisException, QueryTest} -import scala.util.Try -class ErrorPositionSuite extends QueryTest { +class ErrorPositionSuite extends QueryTest with BeforeAndAfter { + + before { + Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") + } + + positionTest("ambiguous attribute reference 1", + "SELECT a from dupAttributes", "a") + + positionTest("ambiguous attribute reference 2", + "SELECT a, b from dupAttributes", "a") + + positionTest("ambiguous attribute reference 3", + "SELECT b, a from dupAttributes", "a") positionTest("unresolved attribute 1", "SELECT x FROM src", "x") @@ -127,6 +144,10 @@ class ErrorPositionSuite extends QueryTest { val error = intercept[AnalysisException] { quietly(sql(query)) } + + assert(!error.getMessage.contains("Seq(")) + assert(!error.getMessage.contains("List(")) + val (line, expectedLineNum) = query.split("\n").zipWithIndex.collect { case (l, i) if l.contains(token) => (l, i + 1) }.headOption.getOrElse(sys.error(s"Invalid test. Token $token not in $query")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 381cd2a29123e..8011952e0d535 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -32,9 +32,12 @@ import org.apache.spark.sql.hive.test.TestHive._ case class TestData(key: Int, value: String) +case class ThreeCloumntable(key: Int, value: String, key1: String) + class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { import org.apache.spark.sql.hive.test.TestHive.implicits._ + val testData = TestHive.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() @@ -186,4 +189,43 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { sql("DROP TABLE hiveTableWithStructValue") } + + test("SPARK-5498:partition schema does not match table schema") { + val testData = TestHive.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.registerTempTable("testData") + + val testDatawithNull = TestHive.sparkContext.parallelize( + (1 to 10).map(i => ThreeCloumntable(i, i.toString,null))).toDF() + + val tmpDir = Utils.createTempDir() + sql(s"CREATE TABLE table_with_partition(key int,value string) PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') SELECT key,value FROM testData") + + // test schema the same between partition and table + sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") + checkAnswer(sql("select key,value from table_with_partition where ds='1' "), + testData.collect.toSeq + ) + + // test difference type of field + sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") + checkAnswer(sql("select key,value from table_with_partition where ds='1' "), + testData.collect.toSeq + ) + + // add column to table + sql("ALTER TABLE table_with_partition ADD COLUMNS(key1 string)") + checkAnswer(sql("select key,value,key1 from table_with_partition where ds='1' "), + testDatawithNull.collect.toSeq + ) + + // change column name to table + sql("ALTER TABLE table_with_partition CHANGE COLUMN key keynew BIGINT") + checkAnswer(sql("select keynew,value from table_with_partition where ds='1' "), + testData.collect.toSeq + ) + + sql("DROP TABLE table_with_partition") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index ff2e6ea9ea51d..e5ad0bf552073 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -579,7 +579,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { Row(3) :: Row(4) :: Nil ) - table("test_parquet_ctas").queryExecution.analyzed match { + table("test_parquet_ctas").queryExecution.optimizedPlan match { case LogicalRelation(p: ParquetRelation2) => // OK case _ => fail( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index cb405f56bf53d..d7c5d1a25a82b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -22,7 +22,7 @@ import java.util import java.util.Properties import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF +import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} @@ -93,6 +93,15 @@ class HiveUdfSuite extends QueryTest { sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") } + test("SPARK-6409 UDAFAverage test") { + sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") + checkAnswer( + sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), + Seq(Row(1.0, 260.182))) + sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") + TestHive.reset() + } + test("SPARK-2693 udaf aggregates test") { checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), sql("SELECT max(key) FROM src").collect().toSeq) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index d891c4e8903d9..432d65a874518 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -292,7 +292,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { Seq(Row(1, "str1")) ) - table("test_parquet_ctas").queryExecution.analyzed match { + table("test_parquet_ctas").queryExecution.optimizedPlan match { case LogicalRelation(p: ParquetRelation2) => // OK case _ => fail( @@ -365,6 +365,31 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { sql("DROP TABLE IF EXISTS test_insert_parquet") } + + test("SPARK-6450 regression test") { + sql( + """CREATE TABLE IF NOT EXISTS ms_convert (key INT) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + // This shouldn't throw AnalysisException + val analyzed = sql( + """SELECT key FROM ms_convert + |UNION ALL + |SELECT key FROM ms_convert + """.stripMargin).queryExecution.analyzed + + assertResult(2) { + analyzed.collect { + case r @ LogicalRelation(_: ParquetRelation2) => r + }.size + } + + sql("DROP TABLE ms_convert") + } } class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 30646ddbc29d8..0ed93c2c5b1fa 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, PrimitiveObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, ObjectInspector, PrimitiveObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} @@ -210,7 +210,7 @@ private[hive] object HiveShim { def getDataLocationPath(p: Partition) = p.getPartitionPath - def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsForPruner(tbl) + def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsForPruner(tbl) def compatibilityBlackList = Seq( "decimal_.*", @@ -244,6 +244,12 @@ private[hive] object HiveShim { } } + def getConvertedOI( + inputOI: ObjectInspector, + outputOI: ObjectInspector): ObjectInspector = { + ObjectInspectorConverters.getConvertedOI(inputOI, outputOI, true) + } + def prepareWritable(w: Writable): Writable = { w } diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index f9fcbdae15745..7577309900209 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive +import java.util import java.util.{ArrayList => JArrayList} import java.util.Properties import java.rmi.server.UID @@ -38,7 +39,7 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory} import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, PrimitiveObjectInspector, ObjectInspector} import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable @@ -400,7 +401,11 @@ private[hive] object HiveShim { Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) } } - + + def getConvertedOI(inputOI: ObjectInspector, outputOI: ObjectInspector): ObjectInspector = { + ObjectInspectorConverters.getConvertedOI(inputOI, outputOI) + } + /* * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that * is needed to initialize before serialization. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index db64e11e16304..f73b463d07779 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -67,12 +67,12 @@ object Checkpoint extends Logging { val REGEX = (PREFIX + """([\d]+)([\w\.]*)""").r /** Get the checkpoint file for the given checkpoint time */ - def checkpointFile(checkpointDir: String, checkpointTime: Time) = { + def checkpointFile(checkpointDir: String, checkpointTime: Time): Path = { new Path(checkpointDir, PREFIX + checkpointTime.milliseconds) } /** Get the checkpoint backup file for the given checkpoint time */ - def checkpointBackupFile(checkpointDir: String, checkpointTime: Time) = { + def checkpointBackupFile(checkpointDir: String, checkpointTime: Time): Path = { new Path(checkpointDir, PREFIX + checkpointTime.milliseconds + ".bk") } @@ -232,6 +232,8 @@ object CheckpointReader extends Logging { def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = { val checkpointPath = new Path(checkpointDir) + + // TODO(rxin): Why is this a def?! def fs = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 0e285d6088ec1..175140481e5ae 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -100,11 +100,11 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { } } - def getInputStreams() = this.synchronized { inputStreams.toArray } + def getInputStreams(): Array[InputDStream[_]] = this.synchronized { inputStreams.toArray } - def getOutputStreams() = this.synchronized { outputStreams.toArray } + def getOutputStreams(): Array[DStream[_]] = this.synchronized { outputStreams.toArray } - def getReceiverInputStreams() = this.synchronized { + def getReceiverInputStreams(): Array[ReceiverInputDStream[_]] = this.synchronized { inputStreams.filter(_.isInstanceOf[ReceiverInputDStream[_]]) .map(_.asInstanceOf[ReceiverInputDStream[_]]) .toArray diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala index a0d8fb5ab93ec..3249bb348981f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala @@ -55,7 +55,6 @@ case class Duration (private val millis: Long) { def div(that: Duration): Double = this / that - def isMultipleOf(that: Duration): Boolean = (this.millis % that.millis == 0) @@ -71,7 +70,7 @@ case class Duration (private val millis: Long) { def milliseconds: Long = millis - def prettyPrint = Utils.msDurationToString(millis) + def prettyPrint: String = Utils.msDurationToString(millis) } @@ -80,7 +79,7 @@ case class Duration (private val millis: Long) { * a given number of milliseconds. */ object Milliseconds { - def apply(milliseconds: Long) = new Duration(milliseconds) + def apply(milliseconds: Long): Duration = new Duration(milliseconds) } /** @@ -88,7 +87,7 @@ object Milliseconds { * a given number of seconds. */ object Seconds { - def apply(seconds: Long) = new Duration(seconds * 1000) + def apply(seconds: Long): Duration = new Duration(seconds * 1000) } /** @@ -96,7 +95,7 @@ object Seconds { * a given number of minutes. */ object Minutes { - def apply(minutes: Long) = new Duration(minutes * 60000) + def apply(minutes: Long): Duration = new Duration(minutes * 60000) } // Java-friendlier versions of the objects above. @@ -107,16 +106,16 @@ object Durations { /** * @return [[org.apache.spark.streaming.Duration]] representing given number of milliseconds. */ - def milliseconds(milliseconds: Long) = Milliseconds(milliseconds) + def milliseconds(milliseconds: Long): Duration = Milliseconds(milliseconds) /** * @return [[org.apache.spark.streaming.Duration]] representing given number of seconds. */ - def seconds(seconds: Long) = Seconds(seconds) + def seconds(seconds: Long): Duration = Seconds(seconds) /** * @return [[org.apache.spark.streaming.Duration]] representing given number of minutes. */ - def minutes(minutes: Long) = Minutes(minutes) + def minutes(minutes: Long): Duration = Minutes(minutes) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala index ad4f3fdd14ad6..3f5be785e1b1a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala @@ -39,18 +39,18 @@ class Interval(val beginTime: Time, val endTime: Time) { this.endTime < that.endTime } - def <= (that: Interval) = (this < that || this == that) + def <= (that: Interval): Boolean = (this < that || this == that) - def > (that: Interval) = !(this <= that) + def > (that: Interval): Boolean = !(this <= that) - def >= (that: Interval) = !(this < that) + def >= (that: Interval): Boolean = !(this < that) - override def toString = "[" + beginTime + ", " + endTime + "]" + override def toString: String = "[" + beginTime + ", " + endTime + "]" } private[streaming] object Interval { - def currentInterval(duration: Duration): Interval = { + def currentInterval(duration: Duration): Interval = { val time = new Time(System.currentTimeMillis) val intervalBegin = time.floor(duration) new Interval(intervalBegin, intervalBegin + duration) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 543224d4b07bc..f57f295874645 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -188,7 +188,7 @@ class StreamingContext private[streaming] ( /** * Return the associated Spark context */ - def sparkContext = sc + def sparkContext: SparkContext = sc /** * Set each DStreams in this context to remember RDDs it generated in the last given duration. @@ -596,7 +596,8 @@ object StreamingContext extends Logging { @deprecated("Replaced by implicit functions in the DStream companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def toPairDStreamFunctions[K, V](stream: DStream[(K, V)]) - (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { + (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) + : PairDStreamFunctions[K, V] = { DStream.toPairDStreamFunctions(stream)(kt, vt, ord) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 2eabdd9387913..73030e15c5661 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -415,8 +415,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T implicit val cmv2: ClassTag[V2] = fakeClassTag implicit val cmw: ClassTag[W] = fakeClassTag - def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] = + def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] = { transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd + } dstream.transformWith[(K2, V2), W](other.dstream, scalaTransform(_, _, _)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 7053f47ec69a2..4c28654ef6413 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -176,11 +176,11 @@ private[python] abstract class PythonDStream( val func = new TransformFunction(pfunc) - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration - val asJavaDStream = JavaDStream.fromDStream(this) + val asJavaDStream: JavaDStream[Array[Byte]] = JavaDStream.fromDStream(this) } /** @@ -212,7 +212,7 @@ private[python] class PythonTransformed2DStream( val func = new TransformFunction(pfunc) - override def dependencies = List(parent, parent2) + override def dependencies: List[DStream[_]] = List(parent, parent2) override def slideDuration: Duration = parent.slideDuration @@ -223,7 +223,7 @@ private[python] class PythonTransformed2DStream( func(Some(rdd1), Some(rdd2), validTime) } - val asJavaDStream = JavaDStream.fromDStream(this) + val asJavaDStream: JavaDStream[Array[Byte]] = JavaDStream.fromDStream(this) } /** @@ -260,12 +260,15 @@ private[python] class PythonReducedWindowedDStream( extends PythonDStream(parent, preduceFunc) { super.persist(StorageLevel.MEMORY_ONLY) - override val mustCheckpoint = true - val invReduceFunc = new TransformFunction(pinvReduceFunc) + override val mustCheckpoint: Boolean = true + + val invReduceFunc: TransformFunction = new TransformFunction(pinvReduceFunc) def windowDuration: Duration = _windowDuration + override def slideDuration: Duration = _slideDuration + override def parentRememberDuration: Duration = rememberDuration + windowDuration override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index b874f561c12eb..795c5aa6d585b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -104,7 +104,7 @@ abstract class DStream[T: ClassTag] ( private[streaming] def parentRememberDuration = rememberDuration /** Return the StreamingContext associated with this DStream */ - def context = ssc + def context: StreamingContext = ssc /* Set the creation call site */ private[streaming] val creationSite = DStream.getCreationSite() @@ -619,14 +619,16 @@ abstract class DStream[T: ClassTag] ( * operator, so this DStream will be registered as an output stream and there materialized. */ def print(num: Int) { - def foreachFunc = (rdd: RDD[T], time: Time) => { - val firstNum = rdd.take(num + 1) - println ("-------------------------------------------") - println ("Time: " + time) - println ("-------------------------------------------") - firstNum.take(num).foreach(println) - if (firstNum.size > num) println("...") - println() + def foreachFunc: (RDD[T], Time) => Unit = { + (rdd: RDD[T], time: Time) => { + val firstNum = rdd.take(num + 1) + println("-------------------------------------------") + println("Time: " + time) + println("-------------------------------------------") + firstNum.take(num).foreach(println) + if (firstNum.size > num) println("...") + println() + } } new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala index 0dc72790fbdbd..39fd21342813e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -114,7 +114,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) } } - override def toString() = { + override def toString: String = { "[\n" + currentCheckpointFiles.size + " checkpoint files \n" + currentCheckpointFiles.mkString("\n") + "\n]" } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 22de8c02e63c8..66d519171fd76 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -298,7 +298,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( private[streaming] class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) { - def hadoopFiles = data.asInstanceOf[mutable.HashMap[Time, Array[String]]] + private def hadoopFiles = data.asInstanceOf[mutable.HashMap[Time, Array[String]]] override def update(time: Time) { hadoopFiles.clear() @@ -320,7 +320,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( } } - override def toString() = { + override def toString: String = { "[\n" + hadoopFiles.size + " file sets\n" + hadoopFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n") + "\n]" } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala index c81534ae584ea..fcd5216f101af 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala @@ -27,7 +27,7 @@ class FilteredDStream[T: ClassTag]( filterFunc: T => Boolean ) extends DStream[T](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala index 658623455498c..9d09a3baf37ca 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala @@ -28,7 +28,7 @@ class FlatMapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag]( flatMapValueFunc: V => TraversableOnce[U] ) extends DStream[(K, U)](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala index c7bb2833eabb8..475ea2d2d4f38 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala @@ -27,7 +27,7 @@ class FlatMappedDStream[T: ClassTag, U: ClassTag]( flatMapFunc: T => Traversable[U] ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index 1361c30395b57..685a32e1d280d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -28,7 +28,7 @@ class ForEachDStream[T: ClassTag] ( foreachFunc: (RDD[T], Time) => Unit ) extends DStream[Unit](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala index a9bb51f054048..dbb295fe54f71 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala @@ -25,7 +25,7 @@ private[streaming] class GlommedDStream[T: ClassTag](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index aa1993f0580a8..e652702e213ef 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -61,7 +61,7 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) } } - override def dependencies = List() + override def dependencies: List[DStream[_]] = List() override def slideDuration: Duration = { if (ssc == null) throw new Exception("ssc is null") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala index 3d8ee29df1e82..5994bc1e23f2b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala @@ -28,7 +28,7 @@ class MapPartitionedDStream[T: ClassTag, U: ClassTag]( preservePartitioning: Boolean ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala index 7aea1f945d9db..954d2eb4a7b00 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala @@ -28,7 +28,7 @@ class MapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag]( mapValueFunc: V => U ) extends DStream[(K, U)](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala index 02704a8d1c2e0..fa14b2e897c3e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala @@ -27,7 +27,7 @@ class MappedDStream[T: ClassTag, U: ClassTag] ( mapFunc: T => U ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index c0a5af0b65cc3..1385ccbf56ee5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -52,7 +52,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( // Reduce each batch of data using reduceByKey which will be further reduced by window // by ReducedWindowedDStream - val reducedStream = parent.reduceByKey(reduceFunc, partitioner) + private val reducedStream = parent.reduceByKey(reduceFunc, partitioner) // Persist RDDs to memory by default as these RDDs are going to be reused. super.persist(StorageLevel.MEMORY_ONLY_SER) @@ -60,7 +60,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( def windowDuration: Duration = _windowDuration - override def dependencies = List(reducedStream) + override def dependencies: List[DStream[_]] = List(reducedStream) override def slideDuration: Duration = _slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala index 880a89bc36895..7757ccac09a58 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala @@ -33,7 +33,7 @@ class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag]( mapSideCombine: Boolean = true ) extends DStream[(K,C)] (parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index ebb04dd35b9a2..de8718d0a80fe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -36,7 +36,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( super.persist(StorageLevel.MEMORY_ONLY_SER) - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala index 71b61856e23c0..5d46ca0715ffd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -32,7 +32,7 @@ class TransformedDStream[U: ClassTag] ( require(parents.map(_.slideDuration).distinct.size == 1, "Some of the DStreams have different slide durations") - override def dependencies = parents.toList + override def dependencies: List[DStream[_]] = parents.toList override def slideDuration: Duration = parents.head.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala index abbc40befa95b..9405dbaa12329 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -33,17 +33,17 @@ class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) require(parents.map(_.slideDuration).distinct.size == 1, "Some of the DStreams have different slide durations") - override def dependencies = parents.toList + override def dependencies: List[DStream[_]] = parents.toList override def slideDuration: Duration = parents.head.slideDuration override def compute(validTime: Time): Option[RDD[T]] = { val rdds = new ArrayBuffer[RDD[T]]() - parents.map(_.getOrCompute(validTime)).foreach(_ match { + parents.map(_.getOrCompute(validTime)).foreach { case Some(rdd) => rdds += rdd case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) - }) + } if (rdds.size > 0) { Some(new UnionRDD(ssc.sc, rdds)) } else { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index 775b6bfd065c0..899865a906c27 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -46,7 +46,7 @@ class WindowedDStream[T: ClassTag]( def windowDuration: Duration = _windowDuration - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = _slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index dd1e96334952f..93caa4ba35c7f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -117,8 +117,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( override def getPreferredLocations(split: Partition): Seq[String] = { val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition] val blockLocations = getBlockIdLocations().get(partition.blockId) - def segmentLocations = HdfsUtils.getFileSegmentLocations( - partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig) - blockLocations.getOrElse(segmentLocations) + blockLocations.getOrElse( + HdfsUtils.getFileSegmentLocations( + partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala index a7d63bd4f2dbf..cd309788a7717 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala @@ -17,6 +17,7 @@ package org.apache.spark.streaming.receiver +import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.duration._ @@ -25,10 +26,10 @@ import scala.reflect.ClassTag import akka.actor._ import akka.actor.SupervisorStrategy.{Escalate, Restart} + import org.apache.spark.{Logging, SparkEnv} -import org.apache.spark.storage.StorageLevel -import java.nio.ByteBuffer import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.storage.StorageLevel /** * :: DeveloperApi :: @@ -149,13 +150,13 @@ private[streaming] class ActorReceiver[T: ClassTag]( class Supervisor extends Actor { override val supervisorStrategy = receiverSupervisorStrategy - val worker = context.actorOf(props, name) + private val worker = context.actorOf(props, name) logInfo("Started receiver worker at:" + worker.path) - val n: AtomicInteger = new AtomicInteger(0) - val hiccups: AtomicInteger = new AtomicInteger(0) + private val n: AtomicInteger = new AtomicInteger(0) + private val hiccups: AtomicInteger = new AtomicInteger(0) - def receive = { + override def receive: PartialFunction[Any, Unit] = { case IteratorData(iterator) => logDebug("received iterator") @@ -189,13 +190,12 @@ private[streaming] class ActorReceiver[T: ClassTag]( } } - def onStart() = { + def onStart(): Unit = { supervisor logInfo("Supervision tree for receivers initialized at:" + supervisor.path) - } - def onStop() = { + def onStop(): Unit = { supervisor ! PoisonPill } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index ee5e639b26d91..42514d8b47dcf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -120,7 +120,7 @@ private[streaming] class BlockGenerator( * `BlockGeneratorListener.onAddData` callback will be called. All received data items * will be periodically pushed into BlockManager. */ - def addDataWithCallback(data: Any, metadata: Any) = synchronized { + def addDataWithCallback(data: Any, metadata: Any): Unit = synchronized { waitToPush() currentBuffer += data listener.onAddData(data, metadata) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 5acf8a9a811ee..5b5a3fe648602 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -245,7 +245,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * Get the unique identifier the receiver input stream that this * receiver is associated with. */ - def streamId = id + def streamId: Int = id /* * ================= diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 1f0244c251eba..4943f29395d12 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -162,13 +162,13 @@ private[streaming] abstract class ReceiverSupervisor( } /** Check if receiver has been marked for stopping */ - def isReceiverStarted() = { + def isReceiverStarted(): Boolean = { logDebug("state = " + receiverState) receiverState == Started } /** Check if receiver has been marked for stopping */ - def isReceiverStopped() = { + def isReceiverStopped(): Boolean = { logDebug("state = " + receiverState) receiverState == Stopped } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 7d29ed88cfcb4..8f2f1fef76874 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable.ArrayBuffer import scala.concurrent.Await -import akka.actor.{Actor, Props} +import akka.actor.{ActorRef, Actor, Props} import akka.pattern.ask import com.google.common.base.Throwables import org.apache.hadoop.conf.Configuration @@ -83,7 +83,7 @@ private[streaming] class ReceiverSupervisorImpl( private val actor = env.actorSystem.actorOf( Props(new Actor { - override def receive() = { + override def receive: PartialFunction[Any, Unit] = { case StopReceiver => logInfo("Received stop signal") stop("Stopped by driver", None) @@ -92,7 +92,7 @@ private[streaming] class ReceiverSupervisorImpl( cleanupOldBlocks(threshTime) } - def ref = self + def ref: ActorRef = self }), "Receiver-" + streamId + "-" + System.currentTimeMillis()) /** Unique block ids if one wants to add blocks directly */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index 7e0f6b2cdfc08..30cf87f5b7dd1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -36,5 +36,5 @@ class Job(val time: Time, func: () => _) { id = "streaming job " + time + "." + number } - override def toString = id + override def toString: String = id } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 59488dfb0f8c6..4946806d2ee95 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -82,7 +82,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { if (eventActor != null) return // generator has already been started eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { - def receive = { + override def receive: PartialFunction[Any, Unit] = { case event: JobGeneratorEvent => processEvent(event) } }), "JobGenerator") @@ -111,8 +111,8 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val pollTime = 100 // To prevent graceful stop to get stuck permanently - def hasTimedOut = { - val timedOut = System.currentTimeMillis() - timeWhenStopStarted > stopTimeout + def hasTimedOut: Boolean = { + val timedOut = (System.currentTimeMillis() - timeWhenStopStarted) > stopTimeout if (timedOut) { logWarning("Timed out while stopping the job generator (timeout = " + stopTimeout + ")") } @@ -133,7 +133,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { logInfo("Stopped generation timer") // Wait for the jobs to complete and checkpoints to be written - def haveAllBatchesBeenProcessed = { + def haveAllBatchesBeenProcessed: Boolean = { lastProcessedBatch != null && lastProcessedBatch.milliseconds == stopTime } logInfo("Waiting for jobs to be processed and checkpoints to be written") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 60bc099b27a4c..d6a93acbe711b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -56,7 +56,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { logDebug("Starting JobScheduler") eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { - def receive = { + override def receive: PartialFunction[Any, Unit] = { case event: JobSchedulerEvent => processEvent(event) } }), "JobScheduler") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 8c15a75b1b0e0..5b134877d0b2d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -28,8 +28,7 @@ private[streaming] case class JobSet( time: Time, jobs: Seq[Job], - receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty - ) { + receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty) { private val incompleteJobs = new HashSet[Job]() private val submissionTime = System.currentTimeMillis() // when this jobset was submitted @@ -48,17 +47,17 @@ case class JobSet( if (hasCompleted) processingEndTime = System.currentTimeMillis() } - def hasStarted = processingStartTime > 0 + def hasStarted: Boolean = processingStartTime > 0 - def hasCompleted = incompleteJobs.isEmpty + def hasCompleted: Boolean = incompleteJobs.isEmpty // Time taken to process all the jobs from the time they started processing // (i.e. not including the time they wait in the streaming scheduler queue) - def processingDelay = processingEndTime - processingStartTime + def processingDelay: Long = processingEndTime - processingStartTime // Time taken to process all the jobs from the time they were submitted // (i.e. including the time they wait in the streaming scheduler queue) - def totalDelay = { + def totalDelay: Long = { processingEndTime - time.milliseconds } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index b36aeb341d25e..98900473138fe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -72,7 +72,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private var actor: ActorRef = null /** Start the actor and receiver execution thread. */ - def start() = synchronized { + def start(): Unit = synchronized { if (actor != null) { throw new SparkException("ReceiverTracker already started") } @@ -86,7 +86,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } /** Stop the receiver execution thread. */ - def stop(graceful: Boolean) = synchronized { + def stop(graceful: Boolean): Unit = synchronized { if (!receiverInputStreams.isEmpty && actor != null) { // First, stop the receivers if (!skipReceiverLaunch) receiverExecutor.stop(graceful) @@ -201,7 +201,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Actor to receive messages from the receivers. */ private class ReceiverTrackerActor extends Actor { - def receive = { + override def receive: PartialFunction[Any, Unit] = { case RegisterReceiver(streamId, typ, host, receiverActor) => registerReceiver(streamId, typ, host, receiverActor, sender) sender ! true @@ -244,16 +244,15 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false if (graceful) { val pollTime = 100 - def done = { receiverInfo.isEmpty && !running } logInfo("Waiting for receiver job to terminate gracefully") - while(!done) { + while (receiverInfo.nonEmpty || running) { Thread.sleep(pollTime) } logInfo("Waited for receiver job to terminate gracefully") } // Check if all the receivers have been deregistered or not - if (!receiverInfo.isEmpty) { + if (receiverInfo.nonEmpty) { logWarning("Not all of the receivers have deregistered, " + receiverInfo) } else { logInfo("All of the receivers have deregistered successfully") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index 5ee53a5c5f561..e4bd067cacb77 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -17,9 +17,10 @@ package org.apache.spark.streaming.ui +import scala.collection.mutable.{Queue, HashMap} + import org.apache.spark.streaming.{Time, StreamingContext} import org.apache.spark.streaming.scheduler._ -import scala.collection.mutable.{Queue, HashMap} import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted import org.apache.spark.streaming.scheduler.StreamingListenerBatchStarted import org.apache.spark.streaming.scheduler.BatchInfo @@ -59,11 +60,13 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } - override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) = synchronized { - runningBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = { + synchronized { + runningBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo + } } - override def onBatchStarted(batchStarted: StreamingListenerBatchStarted) = synchronized { + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = synchronized { runningBatchInfos(batchStarted.batchInfo.batchTime) = batchStarted.batchInfo waitingBatchInfos.remove(batchStarted.batchInfo.batchTime) @@ -72,19 +75,21 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } - override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) = synchronized { - waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime) - runningBatchInfos.remove(batchCompleted.batchInfo.batchTime) - completedaBatchInfos.enqueue(batchCompleted.batchInfo) - if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue() - totalCompletedBatches += 1L - - batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => - totalProcessedRecords += infos.map(_.numRecords).sum + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = { + synchronized { + waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime) + runningBatchInfos.remove(batchCompleted.batchInfo.batchTime) + completedaBatchInfos.enqueue(batchCompleted.batchInfo) + if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue() + totalCompletedBatches += 1L + + batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => + totalProcessedRecords += infos.map(_.numRecords).sum + } } } - def numReceivers = synchronized { + def numReceivers: Int = synchronized { ssc.graph.getReceiverInputStreams().size } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index a73d6f3bf0661..4d968f8bfa7a8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -18,9 +18,7 @@ package org.apache.spark.streaming.util import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.util.collection.OpenHashMap -import scala.collection.JavaConversions.mapAsScalaMap private[streaming] object RawTextHelper { @@ -71,7 +69,7 @@ object RawTextHelper { var count = 0 while(data.hasNext) { - value = data.next + value = data.next() if (value != null) { count += 1 if (len == 0) { @@ -108,9 +106,13 @@ object RawTextHelper { } } - def add(v1: Long, v2: Long) = (v1 + v2) + def add(v1: Long, v2: Long): Long = { + v1 + v2 + } - def subtract(v1: Long, v2: Long) = (v1 - v2) + def subtract(v1: Long, v2: Long): Long = { + v1 - v2 + } - def max(v1: Long, v2: Long) = math.max(v1, v2) + def max(v1: Long, v2: Long): Long = math.max(v1, v2) }