diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index 536aedb6f9fe9..f979ffa16641a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -91,7 +91,7 @@ private[deploy] class ApplicationInfo(
}
}
- private[master] val requestedCores = desc.maxCores.getOrElse(defaultCores)
+ private val requestedCores = desc.maxCores.getOrElse(defaultCores)
private[master] def coresLeft: Int = requestedCores - coresGranted
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index 46509e39c0f23..45412a35e9a7d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -75,16 +75,12 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
val workers = state.workers.sortBy(_.id)
val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers)
- val activeAppHeaders = Seq("Application ID", "Name", "Cores in Use",
- "Cores Requested", "Memory per Node", "Submitted Time", "User", "State", "Duration")
+ val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time",
+ "User", "State", "Duration")
val activeApps = state.activeApps.sortBy(_.startTime).reverse
- val activeAppsTable = UIUtils.listingTable(activeAppHeaders, activeAppRow, activeApps)
-
- val completedAppHeaders = Seq("Application ID", "Name", "Cores Requested", "Memory per Node",
- "Submitted Time", "User", "State", "Duration")
+ val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps)
val completedApps = state.completedApps.sortBy(_.endTime).reverse
- val completedAppsTable = UIUtils.listingTable(completedAppHeaders, completeAppRow,
- completedApps)
+ val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps)
val driverHeaders = Seq("Submission ID", "Submitted Time", "Worker", "State", "Cores",
"Memory", "Main Class")
@@ -191,7 +187,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
}
- private def appRow(app: ApplicationInfo, active: Boolean): Seq[Node] = {
+ private def appRow(app: ApplicationInfo): Seq[Node] = {
val killLink = if (parent.killEnabled &&
(app.state == ApplicationState.RUNNING || app.state == ApplicationState.WAITING)) {
val killLinkUri = s"app/kill?id=${app.id}&terminate=true"
@@ -201,7 +197,6 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
(kill)
}
-
{app.id}
@@ -210,15 +205,8 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
|
{app.desc.name}
|
- {
- if (active) {
-
- {app.coresGranted}
- |
- }
- }
- {if (app.requestedCores == Int.MaxValue) "*" else app.requestedCores}
+ {app.coresGranted}
|
{Utils.megabytesToString(app.desc.memoryPerSlave)}
@@ -230,14 +218,6 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
|
}
- private def activeAppRow(app: ApplicationInfo): Seq[Node] = {
- appRow(app, active = true)
- }
-
- private def completeAppRow(app: ApplicationInfo): Seq[Node] = {
- appRow(app, active = false)
- }
-
private def driverRow(driver: DriverInfo): Seq[Node] = {
val killLink = if (parent.killEnabled &&
(driver.state == DriverState.RUNNING ||
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 6fa1f2c880f7a..132a9ced77700 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -81,9 +81,11 @@ class TaskInfo(
def status: String = {
if (running) {
- "RUNNING"
- } else if (gettingResult) {
- "GET RESULT"
+ if (gettingResult) {
+ "GET RESULT"
+ } else {
+ "RUNNING"
+ }
} else if (failed) {
"FAILED"
} else if (successful) {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
index 660df00bc32f5..d0178dfde6935 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
@@ -112,6 +112,7 @@ class FileShuffleBlockManager(conf: SparkConf)
private val shuffleState = shuffleStates(shuffleId)
private var fileGroup: ShuffleFileGroup = null
+ val openStartTime = System.nanoTime
val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
fileGroup = getUnusedFileGroup()
Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
@@ -135,6 +136,9 @@ class FileShuffleBlockManager(conf: SparkConf)
blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics)
}
}
+ // Creating the file to write to and creating a disk writer both involve interacting with
+ // the disk, so should be included in the shuffle write time.
+ writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime)
override def releaseWriters(success: Boolean) {
if (consolidateShuffleFiles) {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index fa2e617762f55..55ea0f17b156a 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -63,6 +63,9 @@ private[spark] class SortShuffleWriter[K, V, C](
sorter.insertAll(records)
}
+ // Don't bother including the time to open the merged output file in the shuffle write time,
+ // because it just opens a single file, so is typically too fast to measure accurately
+ // (see SPARK-3570).
val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId)
val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index e03442894c5cc..797c9404bc449 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -269,11 +269,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
+: getFormattedTimeQuantiles(serializationTimes)
val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) =>
- if (info.gettingResultTime > 0) {
- (info.finishTime - info.gettingResultTime).toDouble
- } else {
- 0.0
- }
+ getGettingResultTime(info).toDouble
}
val gettingResultQuantiles =
@@ -464,7 +460,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L)
val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L)
val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L)
- val gettingResultTime = info.gettingResultTime
+ val gettingResultTime = getGettingResultTime(info)
val maybeAccumulators = info.accumulables
val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"}
@@ -627,6 +623,19 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
| {errorSummary}{details} |
}
+ private def getGettingResultTime(info: TaskInfo): Long = {
+ if (info.gettingResultTime > 0) {
+ if (info.finishTime > 0) {
+ info.finishTime - info.gettingResultTime
+ } else {
+ // The task is still fetching the result.
+ System.currentTimeMillis - info.gettingResultTime
+ }
+ } else {
+ 0L
+ }
+ }
+
private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = {
val totalExecutionTime =
if (info.gettingResult) {
@@ -638,6 +647,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
}
val executorOverhead = (metrics.executorDeserializeTime +
metrics.resultSerializationTime)
- math.max(0, totalExecutionTime - metrics.executorRunTime - executorOverhead)
+ math.max(
+ 0,
+ totalExecutionTime - metrics.executorRunTime - executorOverhead - getGettingResultTime(info))
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index d9a671687aad0..0b5a914e7dbbf 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1876,6 +1876,10 @@ private[spark] object Utils extends Logging {
startService: Int => (T, Int),
conf: SparkConf,
serviceName: String = ""): (T, Int) = {
+
+ require(startPort == 0 || (1024 <= startPort && startPort < 65536),
+ "startPort should be between 1024 and 65535 (inclusive), or 0 for a random free port.")
+
val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'"
val maxRetries = portMaxRetries(conf)
for (offset <- 0 to maxRetries) {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 3262e670c2030..b962c101c91da 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -352,6 +352,7 @@ private[spark] class ExternalSorter[K, V, C](
// Create our file writers if we haven't done so yet
if (partitionWriters == null) {
curWriteMetrics = new ShuffleWriteMetrics()
+ val openStartTime = System.nanoTime
partitionWriters = Array.fill(numPartitions) {
// Because these files may be read during shuffle, their compression must be controlled by
// spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
@@ -359,6 +360,10 @@ private[spark] class ExternalSorter[K, V, C](
val (blockId, file) = diskBlockManager.createTempShuffleBlock()
blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open()
}
+ // Creating the file to write to and creating a disk writer both involve interacting with
+ // the disk, and can take a long time in aggregate when we open many files, so should be
+ // included in the shuffle write time.
+ curWriteMetrics.incShuffleWriteTime(System.nanoTime - openStartTime)
}
// No need to sort stuff, just write each element out
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
index c52591b352340..efc2482c74ddf 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -53,6 +53,15 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size
+ /** Tests whether this map contains a binding for a key. */
+ def contains(k: K): Boolean = {
+ if (k == null) {
+ haveNullValue
+ } else {
+ _keySet.getPos(k) != OpenHashSet.INVALID_POS
+ }
+ }
+
/** Get the value for a given key */
def apply(k: K): V = {
if (k == null) {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index c80057f95e0b2..1501111a06655 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -122,7 +122,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
*/
def addWithoutResize(k: T): Int = {
var pos = hashcode(hasher.hash(k)) & _mask
- var i = 1
+ var delta = 1
while (true) {
if (!_bitset.get(pos)) {
// This is a new key.
@@ -134,14 +134,12 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
// Found an existing key.
return pos
} else {
- val delta = i
+ // quadratic probing with values increase by 1, 2, 3, ...
pos = (pos + delta) & _mask
- i += 1
+ delta += 1
}
}
- // Never reached here
- assert(INVALID_POS != INVALID_POS)
- INVALID_POS
+ throw new RuntimeException("Should never reach here.")
}
/**
@@ -163,21 +161,19 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
*/
def getPos(k: T): Int = {
var pos = hashcode(hasher.hash(k)) & _mask
- var i = 1
- val maxProbe = _data.size
- while (i < maxProbe) {
+ var delta = 1
+ while (true) {
if (!_bitset.get(pos)) {
return INVALID_POS
} else if (k == _data(pos)) {
return pos
} else {
- val delta = i
+ // quadratic probing with values increase by 1, 2, 3, ...
pos = (pos + delta) & _mask
- i += 1
+ delta += 1
}
}
- // Never reached here
- INVALID_POS
+ throw new RuntimeException("Should never reach here.")
}
/** Return the value at the specified position. */
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
index 61e22642761f0..b4ec4ea521253 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -48,6 +48,11 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
override def size: Int = _keySet.size
+ /** Tests whether this map contains a binding for a key. */
+ def contains(k: K): Boolean = {
+ _keySet.getPos(k) != OpenHashSet.INVALID_POS
+ }
+
/** Get the value for a given key */
def apply(k: K): V = {
val pos = _keySet.getPos(k)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
index 6a70877356409..ef890d2ba60f3 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
@@ -176,4 +176,14 @@ class OpenHashMapSuite extends FunSuite with Matchers {
assert(map(i.toString) === i.toString)
}
}
+
+ test("contains") {
+ val map = new OpenHashMap[String, Int](2)
+ map("a") = 1
+ assert(map.contains("a"))
+ assert(!map.contains("b"))
+ assert(!map.contains(null))
+ map(null) = 0
+ assert(map.contains(null))
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
index 8c7df7d73dcd3..caf378fec8b3e 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
@@ -118,4 +118,11 @@ class PrimitiveKeyOpenHashMapSuite extends FunSuite with Matchers {
assert(map(i.toLong) === i.toString)
}
}
+
+ test("contains") {
+ val map = new PrimitiveKeyOpenHashMap[Int, Int](1)
+ map(0) = 0
+ assert(map.contains(0))
+ assert(!map.contains(1))
+ }
}
diff --git a/dev/run-tests b/dev/run-tests
index d6935a61c6d29..561d7fc9e7b1f 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -178,6 +178,15 @@ CURRENT_BLOCK=$BLOCK_BUILD
fi
}
+echo ""
+echo "========================================================================="
+echo "Detecting binary incompatibilities with MiMa"
+echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_MIMA
+
+./dev/mima
+
echo ""
echo "========================================================================="
echo "Running Spark unit tests"
@@ -227,12 +236,3 @@ echo "========================================================================="
CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS
./python/run-tests
-
-echo ""
-echo "========================================================================="
-echo "Detecting binary incompatibilities with MiMa"
-echo "========================================================================="
-
-CURRENT_BLOCK=$BLOCK_MIMA
-
-./dev/mima
diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh
index 1348e0609dda4..8ab6db6925d6e 100644
--- a/dev/run-tests-codes.sh
+++ b/dev/run-tests-codes.sh
@@ -22,6 +22,6 @@ readonly BLOCK_RAT=11
readonly BLOCK_SCALA_STYLE=12
readonly BLOCK_PYTHON_STYLE=13
readonly BLOCK_BUILD=14
-readonly BLOCK_SPARK_UNIT_TESTS=15
-readonly BLOCK_PYSPARK_UNIT_TESTS=16
-readonly BLOCK_MIMA=17
+readonly BLOCK_MIMA=15
+readonly BLOCK_SPARK_UNIT_TESTS=16
+readonly BLOCK_PYSPARK_UNIT_TESTS=17
diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins
index 5f4000e83925c..3a937b637e003 100755
--- a/dev/run-tests-jenkins
+++ b/dev/run-tests-jenkins
@@ -199,12 +199,12 @@ done
failing_test="Python style tests"
elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then
failing_test="to build"
+ elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then
+ failing_test="MiMa tests"
elif [ "$test_result" -eq "$BLOCK_SPARK_UNIT_TESTS" ]; then
failing_test="Spark unit tests"
elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then
failing_test="PySpark unit tests"
- elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then
- failing_test="MiMa tests"
else
failing_test="some tests"
fi
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index da6aef7f14c4c..c08c76d226713 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -408,31 +408,31 @@ import org.apache.spark.sql.SQLContext;
// Labeled and unlabeled instance types.
// Spark SQL can infer schema from Java Beans.
public class Document implements Serializable {
- private Long id;
+ private long id;
private String text;
- public Document(Long id, String text) {
+ public Document(long id, String text) {
this.id = id;
this.text = text;
}
- public Long getId() { return this.id; }
- public void setId(Long id) { this.id = id; }
+ public long getId() { return this.id; }
+ public void setId(long id) { this.id = id; }
public String getText() { return this.text; }
public void setText(String text) { this.text = text; }
}
public class LabeledDocument extends Document implements Serializable {
- private Double label;
+ private double label;
- public LabeledDocument(Long id, String text, Double label) {
+ public LabeledDocument(long id, String text, double label) {
super(id, text);
this.label = label;
}
- public Double getLabel() { return this.label; }
- public void setLabel(Double label) { this.label = label; }
+ public double getLabel() { return this.label; }
+ public void setLabel(double label) { this.label = label; }
}
// Set up contexts.
@@ -565,6 +565,11 @@ import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{Row, SQLContext}
+// Labeled and unlabeled instance types.
+// Spark SQL can infer schema from case classes.
+case class LabeledDocument(id: Long, text: String, label: Double)
+case class Document(id: Long, text: String)
+
val conf = new SparkConf().setAppName("CrossValidatorExample")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
@@ -655,6 +660,36 @@ import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
+// Labeled and unlabeled instance types.
+// Spark SQL can infer schema from Java Beans.
+public class Document implements Serializable {
+ private long id;
+ private String text;
+
+ public Document(long id, String text) {
+ this.id = id;
+ this.text = text;
+ }
+
+ public long getId() { return this.id; }
+ public void setId(long id) { this.id = id; }
+
+ public String getText() { return this.text; }
+ public void setText(String text) { this.text = text; }
+}
+
+public class LabeledDocument extends Document implements Serializable {
+ private double label;
+
+ public LabeledDocument(long id, String text, double label) {
+ super(id, text);
+ this.label = label;
+ }
+
+ public double getLabel() { return this.label; }
+ public void setLabel(double label) { this.label = label; }
+}
+
SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
SQLContext jsql = new SQLContext(jsc);
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index 5fe832b6fa100..f5b775da7930a 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -237,9 +237,13 @@ You can customize the `ipython` command by setting `PYSPARK_DRIVER_PYTHON_OPTS`.
the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support:
{% highlight bash %}
-$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark
+$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook" ./bin/pyspark
{% endhighlight %}
+After the IPython Notebook server is launched, you can create a new "Python 2" notebook from
+the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of
+your notebook before you start to try Spark from the IPython notebook.
+
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 68b1aeb8ebd01..d9f3eb2b74b18 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -274,6 +274,6 @@ If you need a reference to the proper location to put log files in the YARN so t
# Important notes
- Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured.
-- The local directories used by Spark executors will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored.
+- In `yarn-cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `yarn-client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `yarn-client` mode, only the Spark executors do.
- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN.
- The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `yarn-cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files.
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 6a333fdb562a7..c99a0b03442c4 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -624,7 +624,8 @@ tuples or lists in the RDD created in the step 1.
For example:
{% highlight python %}
# Import SQLContext and data types
-from pyspark.sql import *
+from pyspark.sql import SQLContext
+from pyspark.sql.types import *
# sc is an existing SparkContext.
sqlContext = SQLContext(sc)
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
index 322de7bf2fed8..51d273af8da84 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
@@ -28,6 +28,7 @@ import scala.language.postfixOps
import com.google.common.base.Charsets
import org.apache.avro.ipc.NettyTransceiver
import org.apache.avro.ipc.specific.SpecificRequestor
+import org.apache.commons.lang3.RandomUtils
import org.apache.flume.source.avro
import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol}
import org.jboss.netty.channel.ChannelPipeline
@@ -40,7 +41,6 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}
-import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerReceiverStarted}
import org.apache.spark.util.Utils
class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging {
@@ -76,7 +76,8 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L
/** Find a free port */
private def findFreePort(): Int = {
- Utils.startServiceOnPort(23456, (trialPort: Int) => {
+ val candidatePort = RandomUtils.nextInt(1024, 65536)
+ Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
val socket = new ServerSocket(trialPort)
socket.close()
(null, trialPort)
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
index 0f3298af6234a..24d78ecb3a97d 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
@@ -25,6 +25,7 @@ import scala.concurrent.duration._
import scala.language.postfixOps
import org.apache.activemq.broker.{TransportConnector, BrokerService}
+import org.apache.commons.lang3.RandomUtils
import org.eclipse.paho.client.mqttv3._
import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
@@ -113,7 +114,8 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
}
private def findFreePort(): Int = {
- Utils.startServiceOnPort(23456, (trialPort: Int) => {
+ val candidatePort = RandomUtils.nextInt(1024, 65536)
+ Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
val socket = new ServerSocket(trialPort)
socket.close()
(null, trialPort)
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
index dc90e9e987234..2da5f7278729e 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
@@ -147,7 +147,6 @@ void addOptionString(List cmd, String options) {
*/
List buildClassPath(String appClassPath) throws IOException {
String sparkHome = getSparkHome();
- String scala = getScalaVersion();
List cp = new ArrayList();
addToClassPath(cp, getenv("SPARK_CLASSPATH"));
@@ -158,6 +157,7 @@ List buildClassPath(String appClassPath) throws IOException {
boolean prependClasses = !isEmpty(getenv("SPARK_PREPEND_CLASSES"));
boolean isTesting = "1".equals(getenv("SPARK_TESTING"));
if (prependClasses || isTesting) {
+ String scala = getScalaVersion();
List projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx",
"streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver",
"yarn", "launcher");
@@ -182,7 +182,7 @@ List buildClassPath(String appClassPath) throws IOException {
addToClassPath(cp, String.format("%s/core/target/jars/*", sparkHome));
}
- String assembly = findAssembly(scala);
+ String assembly = findAssembly();
addToClassPath(cp, assembly);
// When Hive support is needed, Datanucleus jars must be included on the classpath. Datanucleus
@@ -330,7 +330,7 @@ String getenv(String key) {
return firstNonEmpty(childEnv.get(key), System.getenv(key));
}
- private String findAssembly(String scalaVersion) {
+ private String findAssembly() {
String sparkHome = getSparkHome();
File libdir;
if (new File(sparkHome, "RELEASE").isFile()) {
@@ -338,7 +338,7 @@ private String findAssembly(String scalaVersion) {
checkState(libdir.isDirectory(), "Library directory '%s' does not exist.",
libdir.getAbsolutePath());
} else {
- libdir = new File(sparkHome, String.format("assembly/target/scala-%s", scalaVersion));
+ libdir = new File(sparkHome, String.format("assembly/target/scala-%s", getScalaVersion()));
}
final Pattern re = Pattern.compile("spark-assembly.*hadoop.*\\.jar");
diff --git a/repl/pom.xml b/repl/pom.xml
index edfa1c7f2c29c..03053b4c3b287 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -84,6 +84,11 @@
scalacheck_${scala.binary.version}
test
+
+ org.mockito
+ mockito-all
+ test
+
diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
index 9805609120005..004941d5f50ae 100644
--- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
@@ -17,9 +17,10 @@
package org.apache.spark.repl
-import java.io.{ByteArrayOutputStream, InputStream, FileNotFoundException}
-import java.net.{URI, URL, URLEncoder}
-import java.util.concurrent.{Executors, ExecutorService}
+import java.io.{IOException, ByteArrayOutputStream, InputStream}
+import java.net.{HttpURLConnection, URI, URL, URLEncoder}
+
+import scala.util.control.NonFatal
import org.apache.hadoop.fs.{FileSystem, Path}
@@ -43,6 +44,9 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
val parentLoader = new ParentClassLoader(parent)
+ // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes
+ private[repl] var httpUrlConnectionTimeoutMillis: Int = -1
+
// Hadoop FileSystem object for our URI, if it isn't using HTTP
var fileSystem: FileSystem = {
if (Set("http", "https", "ftp").contains(uri.getScheme)) {
@@ -71,30 +75,66 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
}
}
+ private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = {
+ val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
+ val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
+ val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
+ newuri.toURL
+ } else {
+ new URL(classUri + "/" + urlEncode(pathInDirectory))
+ }
+ val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(),
+ SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection]
+ // Set the connection timeouts (for testing purposes)
+ if (httpUrlConnectionTimeoutMillis != -1) {
+ connection.setConnectTimeout(httpUrlConnectionTimeoutMillis)
+ connection.setReadTimeout(httpUrlConnectionTimeoutMillis)
+ }
+ connection.connect()
+ try {
+ if (connection.getResponseCode != 200) {
+ // Close the error stream so that the connection is eligible for re-use
+ try {
+ connection.getErrorStream.close()
+ } catch {
+ case ioe: IOException =>
+ logError("Exception while closing error stream", ioe)
+ }
+ throw new ClassNotFoundException(s"Class file not found at URL $url")
+ } else {
+ connection.getInputStream
+ }
+ } catch {
+ case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] =>
+ connection.disconnect()
+ throw e
+ }
+ }
+
+ private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = {
+ val path = new Path(directory, pathInDirectory)
+ if (fileSystem.exists(path)) {
+ fileSystem.open(path)
+ } else {
+ throw new ClassNotFoundException(s"Class file not found at path $path")
+ }
+ }
+
def findClassLocally(name: String): Option[Class[_]] = {
+ val pathInDirectory = name.replace('.', '/') + ".class"
+ var inputStream: InputStream = null
try {
- val pathInDirectory = name.replace('.', '/') + ".class"
- val inputStream = {
+ inputStream = {
if (fileSystem != null) {
- fileSystem.open(new Path(directory, pathInDirectory))
+ getClassFileInputStreamFromFileSystem(pathInDirectory)
} else {
- val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
- val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
- val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
- newuri.toURL
- } else {
- new URL(classUri + "/" + urlEncode(pathInDirectory))
- }
-
- Utils.setupSecureURLConnection(url.openConnection(), SparkEnv.get.securityManager)
- .getInputStream
+ getClassFileInputStreamFromHttpServer(pathInDirectory)
}
}
val bytes = readAndTransformClass(name, inputStream)
- inputStream.close()
Some(defineClass(name, bytes, 0, bytes.length))
} catch {
- case e: FileNotFoundException =>
+ case e: ClassNotFoundException =>
// We did not find the class
logDebug(s"Did not load class $name from REPL class server at $uri", e)
None
@@ -102,6 +142,15 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
// Something bad happened while checking if the class exists
logError(s"Failed to check existence of class $name on REPL class server at $uri", e)
None
+ } finally {
+ if (inputStream != null) {
+ try {
+ inputStream.close()
+ } catch {
+ case e: Exception =>
+ logError("Exception while closing inputStream", e)
+ }
+ }
}
}
diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
index 6a79e76a34db8..c709cde740748 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
@@ -20,13 +20,25 @@ package org.apache.spark.repl
import java.io.File
import java.net.{URL, URLClassLoader}
+import scala.concurrent.duration._
+import scala.language.implicitConversions
+import scala.language.postfixOps
+
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
+import org.scalatest.concurrent.Interruptor
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.mock.MockitoSugar
+import org.mockito.Mockito._
-import org.apache.spark.{SparkConf, TestUtils}
+import org.apache.spark._
import org.apache.spark.util.Utils
-class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
+class ExecutorClassLoaderSuite
+ extends FunSuite
+ with BeforeAndAfterAll
+ with MockitoSugar
+ with Logging {
val childClassNames = List("ReplFakeClass1", "ReplFakeClass2")
val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3")
@@ -34,6 +46,7 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
var tempDir2: File = _
var url1: String = _
var urls2: Array[URL] = _
+ var classServer: HttpServer = _
override def beforeAll() {
super.beforeAll()
@@ -47,8 +60,12 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
override def afterAll() {
super.afterAll()
+ if (classServer != null) {
+ classServer.stop()
+ }
Utils.deleteRecursively(tempDir1)
Utils.deleteRecursively(tempDir2)
+ SparkEnv.set(null)
}
test("child first") {
@@ -83,4 +100,53 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
}
}
+ test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") {
+ // This is a regression test for SPARK-6209, a bug where each failed attempt to load a class
+ // from the driver's class server would leak a HTTP connection, causing the class server's
+ // thread / connection pool to be exhausted.
+ val conf = new SparkConf()
+ val securityManager = new SecurityManager(conf)
+ classServer = new HttpServer(conf, tempDir1, securityManager)
+ classServer.start()
+ // ExecutorClassLoader uses SparkEnv's SecurityManager, so we need to mock this
+ val mockEnv = mock[SparkEnv]
+ when(mockEnv.securityManager).thenReturn(securityManager)
+ SparkEnv.set(mockEnv)
+ // Create an ExecutorClassLoader that's configured to load classes from the HTTP server
+ val parentLoader = new URLClassLoader(Array.empty, null)
+ val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false)
+ classLoader.httpUrlConnectionTimeoutMillis = 500
+ // Check that this class loader can actually load classes that exist
+ val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance()
+ val fakeClassVersion = fakeClass.toString
+ assert(fakeClassVersion === "1")
+ // Try to perform a full GC now, since GC during the test might mask resource leaks
+ System.gc()
+ // When the original bug occurs, the test thread becomes blocked in a classloading call
+ // and does not respond to interrupts. Therefore, use a custom ScalaTest interruptor to
+ // shut down the HTTP server when the test times out
+ val interruptor: Interruptor = new Interruptor {
+ override def apply(thread: Thread): Unit = {
+ classServer.stop()
+ classServer = null
+ thread.interrupt()
+ }
+ }
+ def tryAndFailToLoadABunchOfClasses(): Unit = {
+ // The number of trials here should be much larger than Jetty's thread / connection limit
+ // in order to expose thread or connection leaks
+ for (i <- 1 to 1000) {
+ if (Thread.currentThread().isInterrupted) {
+ throw new InterruptedException()
+ }
+ // Incorporate the iteration number into the class name in order to avoid any response
+ // caching that might be added in the future
+ intercept[ClassNotFoundException] {
+ classLoader.loadClass(s"ReplFakeClassDoesNotExist$i").newInstance()
+ }
+ }
+ }
+ failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor)
+ }
+
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
index 15add84878ecf..34fedead44db3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
@@ -30,6 +30,12 @@ class AnalysisException protected[sql] (
val startPosition: Option[Int] = None)
extends Exception with Serializable {
+ def withPosition(line: Option[Int], startPosition: Option[Int]) = {
+ val newException = new AnalysisException(message, line, startPosition)
+ newException.setStackTrace(getStackTrace)
+ newException
+ }
+
override def getMessage: String = {
val lineAnnotation = line.map(l => s" line $l").getOrElse("")
val positionAnnotation = startPosition.map(p => s" pos $p").getOrElse("")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
index 366be00473d1c..3823584287741 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
@@ -26,7 +26,7 @@ import scala.util.parsing.input.CharArrayReader.EofCh
import org.apache.spark.sql.catalyst.plans.logical._
private[sql] object KeywordNormalizer {
- def apply(str: String) = str.toLowerCase()
+ def apply(str: String): String = str.toLowerCase()
}
private[sql] abstract class AbstractSparkSQLParser
@@ -42,7 +42,7 @@ private[sql] abstract class AbstractSparkSQLParser
}
protected case class Keyword(str: String) {
- def normalize = KeywordNormalizer(str)
+ def normalize: String = KeywordNormalizer(str)
def parser: Parser[String] = normalize
}
@@ -81,7 +81,7 @@ private[sql] abstract class AbstractSparkSQLParser
class SqlLexical extends StdLexical {
case class FloatLit(chars: String) extends Token {
- override def toString = chars
+ override def toString: String = chars
}
/* This is a work around to support the lazy setting */
@@ -120,7 +120,7 @@ class SqlLexical extends StdLexical {
| failure("illegal character")
)
- override def identChar = letter | elem('_')
+ override def identChar: Parser[Elem] = letter | elem('_')
override def whitespace: Parser[Any] =
( whitespaceChar
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 92d3db077c5e1..44eceb0b372e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -64,9 +64,7 @@ class Analyzer(catalog: Catalog,
UnresolvedHavingClauseAttributes ::
TrimGroupingAliases ::
typeCoercionRules ++
- extendedResolutionRules : _*),
- Batch("Remove SubQueries", fixedPoint,
- EliminateSubQueries)
+ extendedResolutionRules : _*)
)
/**
@@ -170,12 +168,12 @@ class Analyzer(catalog: Catalog,
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
object ResolveRelations extends Rule[LogicalPlan] {
- def getTable(u: UnresolvedRelation) = {
+ def getTable(u: UnresolvedRelation): LogicalPlan = {
try {
catalog.lookupRelation(u.tableIdentifier, u.alias)
} catch {
case _: NoSuchTableException =>
- u.failAnalysis(s"no such table ${u.tableIdentifier}")
+ u.failAnalysis(s"no such table ${u.tableName}")
}
}
@@ -275,7 +273,8 @@ class Analyzer(catalog: Catalog,
q.asInstanceOf[GroupingAnalytics].gid
case u @ UnresolvedAttribute(name) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
- val result = q.resolveChildren(name, resolver).getOrElse(u)
+ val result =
+ withPosition(u) { q.resolveChildren(name, resolver).getOrElse(u) }
logDebug(s"Resolving $u to $result")
result
case UnresolvedGetField(child, fieldName) if child.resolved =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index 9e6e2912e0622..5eb7dff0cede8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -86,12 +86,12 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
tables += ((getDbTableName(tableIdent), plan))
}
- override def unregisterTable(tableIdentifier: Seq[String]) = {
+ override def unregisterTable(tableIdentifier: Seq[String]): Unit = {
val tableIdent = processTableIdentifier(tableIdentifier)
tables -= getDbTableName(tableIdent)
}
- override def unregisterAllTables() = {
+ override def unregisterAllTables(): Unit = {
tables.clear()
}
@@ -147,8 +147,8 @@ trait OverrideCatalog extends Catalog {
}
abstract override def lookupRelation(
- tableIdentifier: Seq[String],
- alias: Option[String] = None): LogicalPlan = {
+ tableIdentifier: Seq[String],
+ alias: Option[String] = None): LogicalPlan = {
val tableIdent = processTableIdentifier(tableIdentifier)
val overriddenTable = overrides.get(getDBTable(tableIdent))
val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r))
@@ -205,15 +205,15 @@ trait OverrideCatalog extends Catalog {
*/
object EmptyCatalog extends Catalog {
- val caseSensitive: Boolean = true
+ override val caseSensitive: Boolean = true
- def tableExists(tableIdentifier: Seq[String]): Boolean = {
+ override def tableExists(tableIdentifier: Seq[String]): Boolean = {
throw new UnsupportedOperationException
}
- def lookupRelation(
- tableIdentifier: Seq[String],
- alias: Option[String] = None) = {
+ override def lookupRelation(
+ tableIdentifier: Seq[String],
+ alias: Option[String] = None): LogicalPlan = {
throw new UnsupportedOperationException
}
@@ -221,11 +221,11 @@ object EmptyCatalog extends Catalog {
throw new UnsupportedOperationException
}
- def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = {
+ override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = {
throw new UnsupportedOperationException
}
- def unregisterTable(tableIdentifier: Seq[String]): Unit = {
+ override def unregisterTable(tableIdentifier: Seq[String]): Unit = {
throw new UnsupportedOperationException
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index fb975ee5e7296..40472a1cbb3b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -33,7 +33,7 @@ class CheckAnalysis {
*/
val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil
- def failAnalysis(msg: String) = {
+ def failAnalysis(msg: String): Nothing = {
throw new AnalysisException(msg)
}
@@ -63,7 +63,7 @@ class CheckAnalysis {
s"filter expression '${f.condition.prettyString}' " +
s"of type ${f.condition.dataType.simpleString} is not a boolean.")
- case aggregatePlan@Aggregate(groupingExprs, aggregateExprs, child) =>
+ case Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case _: AggregateExpression => // OK
case e: Attribute if !groupingExprs.contains(e) =>
@@ -85,13 +85,18 @@ class CheckAnalysis {
cleaned.foreach(checkValidAggregateExpression)
+ case _ => // Fallbacks to the following checks
+ }
+
+ operator match {
case o if o.children.nonEmpty && o.missingInput.nonEmpty =>
- val missingAttributes = o.missingInput.map(_.prettyString).mkString(",")
- val input = o.inputSet.map(_.prettyString).mkString(",")
+ val missingAttributes = o.missingInput.mkString(",")
+ val input = o.inputSet.mkString(",")
- failAnalysis(s"resolved attributes $missingAttributes missing from $input")
+ failAnalysis(
+ s"resolved attribute(s) $missingAttributes missing from $input " +
+ s"in operator ${operator.simpleString}")
- // Catch all
case o if !o.resolved =>
failAnalysis(
s"unresolved operator ${operator.simpleString}")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 9f334f6d42ad1..c43ea55899695 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -35,7 +35,7 @@ trait OverrideFunctionRegistry extends FunctionRegistry {
val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive)
- def registerFunction(name: String, builder: FunctionBuilder) = {
+ override def registerFunction(name: String, builder: FunctionBuilder): Unit = {
functionBuilders.put(name, builder)
}
@@ -47,7 +47,7 @@ trait OverrideFunctionRegistry extends FunctionRegistry {
class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistry {
val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive)
- def registerFunction(name: String, builder: FunctionBuilder) = {
+ override def registerFunction(name: String, builder: FunctionBuilder): Unit = {
functionBuilders.put(name, builder)
}
@@ -61,13 +61,15 @@ class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistr
* functions are already filled in and the analyser needs only to resolve attribute references.
*/
object EmptyFunctionRegistry extends FunctionRegistry {
- def registerFunction(name: String, builder: FunctionBuilder) = ???
+ override def registerFunction(name: String, builder: FunctionBuilder): Unit = {
+ throw new UnsupportedOperationException
+ }
- def lookupFunction(name: String, children: Seq[Expression]): Expression = {
+ override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
throw new UnsupportedOperationException
}
- def caseSensitive: Boolean = ???
+ override def caseSensitive: Boolean = throw new UnsupportedOperationException
}
/**
@@ -76,7 +78,7 @@ object EmptyFunctionRegistry extends FunctionRegistry {
* TODO move this into util folder?
*/
object StringKeyHashMap {
- def apply[T](caseSensitive: Boolean) = caseSensitive match {
+ def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match {
case false => new StringKeyHashMap[T](_.toLowerCase)
case true => new StringKeyHashMap[T](identity)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
index e95f19e69ed43..c61c395cb4bb1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
@@ -38,8 +38,16 @@ package object analysis {
implicit class AnalysisErrorAt(t: TreeNode[_]) {
/** Fails the analysis at the point where a specific tree node was parsed. */
- def failAnalysis(msg: String) = {
+ def failAnalysis(msg: String): Nothing = {
throw new AnalysisException(msg, t.origin.line, t.origin.startPosition)
}
}
+
+ /** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */
+ def withPosition[A](t: TreeNode[_])(f: => A) = {
+ try f catch {
+ case a: AnalysisException =>
+ throw a.withPosition(t.origin.line, t.origin.startPosition)
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index a7cd4124e56f3..300e9ba187bc5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.types.DataType
/**
* Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully
@@ -36,7 +37,12 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str
case class UnresolvedRelation(
tableIdentifier: Seq[String],
alias: Option[String] = None) extends LeafNode {
- override def output = Nil
+
+ /** Returns a `.` separated name for this relation. */
+ def tableName: String = tableIdentifier.mkString(".")
+
+ override def output: Seq[Attribute] = Nil
+
override lazy val resolved = false
}
@@ -44,16 +50,16 @@ case class UnresolvedRelation(
* Holds the name of an attribute that has yet to be resolved.
*/
case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] {
- override def exprId = throw new UnresolvedException(this, "exprId")
- override def dataType = throw new UnresolvedException(this, "dataType")
- override def nullable = throw new UnresolvedException(this, "nullable")
- override def qualifiers = throw new UnresolvedException(this, "qualifiers")
+ override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
override lazy val resolved = false
- override def newInstance() = this
- override def withNullability(newNullability: Boolean) = this
- override def withQualifiers(newQualifiers: Seq[String]) = this
- override def withName(newName: String) = UnresolvedAttribute(name)
+ override def newInstance(): UnresolvedAttribute = this
+ override def withNullability(newNullability: Boolean): UnresolvedAttribute = this
+ override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this
+ override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute(name)
// Unresolved attributes are transient at compile time and don't get evaluated during execution.
override def eval(input: Row = null): EvaluatedType =
@@ -63,16 +69,16 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
}
case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression {
- override def dataType = throw new UnresolvedException(this, "dataType")
- override def foldable = throw new UnresolvedException(this, "foldable")
- override def nullable = throw new UnresolvedException(this, "nullable")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
// Unresolved functions are transient at compile time and don't get evaluated during execution.
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
- override def toString = s"'$name(${children.mkString(",")})"
+ override def toString: String = s"'$name(${children.mkString(",")})"
}
/**
@@ -82,17 +88,17 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E
trait Star extends Attribute with trees.LeafNode[Expression] {
self: Product =>
- override def name = throw new UnresolvedException(this, "name")
- override def exprId = throw new UnresolvedException(this, "exprId")
- override def dataType = throw new UnresolvedException(this, "dataType")
- override def nullable = throw new UnresolvedException(this, "nullable")
- override def qualifiers = throw new UnresolvedException(this, "qualifiers")
+ override def name: String = throw new UnresolvedException(this, "name")
+ override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
override lazy val resolved = false
- override def newInstance() = this
- override def withNullability(newNullability: Boolean) = this
- override def withQualifiers(newQualifiers: Seq[String]) = this
- override def withName(newName: String) = this
+ override def newInstance(): Star = this
+ override def withNullability(newNullability: Boolean): Star = this
+ override def withQualifiers(newQualifiers: Seq[String]): Star = this
+ override def withName(newName: String): Star = this
// Star gets expanded at runtime so we never evaluate a Star.
override def eval(input: Row = null): EvaluatedType =
@@ -125,7 +131,7 @@ case class UnresolvedStar(table: Option[String]) extends Star {
}
}
- override def toString = table.map(_ + ".").getOrElse("") + "*"
+ override def toString: String = table.map(_ + ".").getOrElse("") + "*"
}
/**
@@ -140,25 +146,25 @@ case class UnresolvedStar(table: Option[String]) extends Star {
case class MultiAlias(child: Expression, names: Seq[String])
extends Attribute with trees.UnaryNode[Expression] {
- override def name = throw new UnresolvedException(this, "name")
+ override def name: String = throw new UnresolvedException(this, "name")
- override def exprId = throw new UnresolvedException(this, "exprId")
+ override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
- override def dataType = throw new UnresolvedException(this, "dataType")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
- override def nullable = throw new UnresolvedException(this, "nullable")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
- override def qualifiers = throw new UnresolvedException(this, "qualifiers")
+ override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
override lazy val resolved = false
- override def newInstance() = this
+ override def newInstance(): MultiAlias = this
- override def withNullability(newNullability: Boolean) = this
+ override def withNullability(newNullability: Boolean): MultiAlias = this
- override def withQualifiers(newQualifiers: Seq[String]) = this
+ override def withQualifiers(newQualifiers: Seq[String]): MultiAlias = this
- override def withName(newName: String) = this
+ override def withName(newName: String): MultiAlias = this
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
@@ -175,17 +181,17 @@ case class MultiAlias(child: Expression, names: Seq[String])
*/
case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions
- override def toString = expressions.mkString("ResolvedStar(", ", ", ")")
+ override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
}
case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
- override def dataType = throw new UnresolvedException(this, "dataType")
- override def foldable = throw new UnresolvedException(this, "foldable")
- override def nullable = throw new UnresolvedException(this, "nullable")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
- override def toString = s"$child.$fieldName"
+ override def toString: String = s"$child.$fieldName"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 51a09ac0e1249..145f062dd6817 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -61,60 +61,60 @@ package object dsl {
trait ImplicitOperators {
def expr: Expression
- def unary_- = UnaryMinus(expr)
- def unary_! = Not(expr)
- def unary_~ = BitwiseNot(expr)
-
- def + (other: Expression) = Add(expr, other)
- def - (other: Expression) = Subtract(expr, other)
- def * (other: Expression) = Multiply(expr, other)
- def / (other: Expression) = Divide(expr, other)
- def % (other: Expression) = Remainder(expr, other)
- def & (other: Expression) = BitwiseAnd(expr, other)
- def | (other: Expression) = BitwiseOr(expr, other)
- def ^ (other: Expression) = BitwiseXor(expr, other)
-
- def && (other: Expression) = And(expr, other)
- def || (other: Expression) = Or(expr, other)
-
- def < (other: Expression) = LessThan(expr, other)
- def <= (other: Expression) = LessThanOrEqual(expr, other)
- def > (other: Expression) = GreaterThan(expr, other)
- def >= (other: Expression) = GreaterThanOrEqual(expr, other)
- def === (other: Expression) = EqualTo(expr, other)
- def <=> (other: Expression) = EqualNullSafe(expr, other)
- def !== (other: Expression) = Not(EqualTo(expr, other))
-
- def in(list: Expression*) = In(expr, list)
-
- def like(other: Expression) = Like(expr, other)
- def rlike(other: Expression) = RLike(expr, other)
- def contains(other: Expression) = Contains(expr, other)
- def startsWith(other: Expression) = StartsWith(expr, other)
- def endsWith(other: Expression) = EndsWith(expr, other)
- def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)) =
+ def unary_- : Expression= UnaryMinus(expr)
+ def unary_! : Predicate = Not(expr)
+ def unary_~ : Expression = BitwiseNot(expr)
+
+ def + (other: Expression): Expression = Add(expr, other)
+ def - (other: Expression): Expression = Subtract(expr, other)
+ def * (other: Expression): Expression = Multiply(expr, other)
+ def / (other: Expression): Expression = Divide(expr, other)
+ def % (other: Expression): Expression = Remainder(expr, other)
+ def & (other: Expression): Expression = BitwiseAnd(expr, other)
+ def | (other: Expression): Expression = BitwiseOr(expr, other)
+ def ^ (other: Expression): Expression = BitwiseXor(expr, other)
+
+ def && (other: Expression): Predicate = And(expr, other)
+ def || (other: Expression): Predicate = Or(expr, other)
+
+ def < (other: Expression): Predicate = LessThan(expr, other)
+ def <= (other: Expression): Predicate = LessThanOrEqual(expr, other)
+ def > (other: Expression): Predicate = GreaterThan(expr, other)
+ def >= (other: Expression): Predicate = GreaterThanOrEqual(expr, other)
+ def === (other: Expression): Predicate = EqualTo(expr, other)
+ def <=> (other: Expression): Predicate = EqualNullSafe(expr, other)
+ def !== (other: Expression): Predicate = Not(EqualTo(expr, other))
+
+ def in(list: Expression*): Expression = In(expr, list)
+
+ def like(other: Expression): Expression = Like(expr, other)
+ def rlike(other: Expression): Expression = RLike(expr, other)
+ def contains(other: Expression): Expression = Contains(expr, other)
+ def startsWith(other: Expression): Expression = StartsWith(expr, other)
+ def endsWith(other: Expression): Expression = EndsWith(expr, other)
+ def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)): Expression =
Substring(expr, pos, len)
- def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)) =
+ def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)): Expression =
Substring(expr, pos, len)
- def isNull = IsNull(expr)
- def isNotNull = IsNotNull(expr)
+ def isNull: Predicate = IsNull(expr)
+ def isNotNull: Predicate = IsNotNull(expr)
- def getItem(ordinal: Expression) = GetItem(expr, ordinal)
- def getField(fieldName: String) = UnresolvedGetField(expr, fieldName)
+ def getItem(ordinal: Expression): Expression = GetItem(expr, ordinal)
+ def getField(fieldName: String): UnresolvedGetField = UnresolvedGetField(expr, fieldName)
- def cast(to: DataType) = Cast(expr, to)
+ def cast(to: DataType): Expression = Cast(expr, to)
- def asc = SortOrder(expr, Ascending)
- def desc = SortOrder(expr, Descending)
+ def asc: SortOrder = SortOrder(expr, Ascending)
+ def desc: SortOrder = SortOrder(expr, Descending)
- def as(alias: String) = Alias(expr, alias)()
- def as(alias: Symbol) = Alias(expr, alias.name)()
+ def as(alias: String): NamedExpression = Alias(expr, alias)()
+ def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)()
}
trait ExpressionConversions {
implicit class DslExpression(e: Expression) extends ImplicitOperators {
- def expr = e
+ def expr: Expression = e
}
implicit def booleanToLiteral(b: Boolean): Literal = Literal(b)
@@ -144,94 +144,100 @@ package object dsl {
}
}
- def sum(e: Expression) = Sum(e)
- def sumDistinct(e: Expression) = SumDistinct(e)
- def count(e: Expression) = Count(e)
- def countDistinct(e: Expression*) = CountDistinct(e)
- def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd)
- def avg(e: Expression) = Average(e)
- def first(e: Expression) = First(e)
- def last(e: Expression) = Last(e)
- def min(e: Expression) = Min(e)
- def max(e: Expression) = Max(e)
- def upper(e: Expression) = Upper(e)
- def lower(e: Expression) = Lower(e)
- def sqrt(e: Expression) = Sqrt(e)
- def abs(e: Expression) = Abs(e)
-
- implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
+ def sum(e: Expression): Expression = Sum(e)
+ def sumDistinct(e: Expression): Expression = SumDistinct(e)
+ def count(e: Expression): Expression = Count(e)
+ def countDistinct(e: Expression*): Expression = CountDistinct(e)
+ def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression =
+ ApproxCountDistinct(e, rsd)
+ def avg(e: Expression): Expression = Average(e)
+ def first(e: Expression): Expression = First(e)
+ def last(e: Expression): Expression = Last(e)
+ def min(e: Expression): Expression = Min(e)
+ def max(e: Expression): Expression = Max(e)
+ def upper(e: Expression): Expression = Upper(e)
+ def lower(e: Expression): Expression = Lower(e)
+ def sqrt(e: Expression): Expression = Sqrt(e)
+ def abs(e: Expression): Expression = Abs(e)
+
+ implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
// TODO more implicit class for literal?
implicit class DslString(val s: String) extends ImplicitOperators {
override def expr: Expression = Literal(s)
- def attr = analysis.UnresolvedAttribute(s)
+ def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s)
}
abstract class ImplicitAttribute extends ImplicitOperators {
def s: String
- def expr = attr
- def attr = analysis.UnresolvedAttribute(s)
+ def expr: UnresolvedAttribute = attr
+ def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s)
/** Creates a new AttributeReference of type boolean */
- def boolean = AttributeReference(s, BooleanType, nullable = true)()
+ def boolean: AttributeReference = AttributeReference(s, BooleanType, nullable = true)()
/** Creates a new AttributeReference of type byte */
- def byte = AttributeReference(s, ByteType, nullable = true)()
+ def byte: AttributeReference = AttributeReference(s, ByteType, nullable = true)()
/** Creates a new AttributeReference of type short */
- def short = AttributeReference(s, ShortType, nullable = true)()
+ def short: AttributeReference = AttributeReference(s, ShortType, nullable = true)()
/** Creates a new AttributeReference of type int */
- def int = AttributeReference(s, IntegerType, nullable = true)()
+ def int: AttributeReference = AttributeReference(s, IntegerType, nullable = true)()
/** Creates a new AttributeReference of type long */
- def long = AttributeReference(s, LongType, nullable = true)()
+ def long: AttributeReference = AttributeReference(s, LongType, nullable = true)()
/** Creates a new AttributeReference of type float */
- def float = AttributeReference(s, FloatType, nullable = true)()
+ def float: AttributeReference = AttributeReference(s, FloatType, nullable = true)()
/** Creates a new AttributeReference of type double */
- def double = AttributeReference(s, DoubleType, nullable = true)()
+ def double: AttributeReference = AttributeReference(s, DoubleType, nullable = true)()
/** Creates a new AttributeReference of type string */
- def string = AttributeReference(s, StringType, nullable = true)()
+ def string: AttributeReference = AttributeReference(s, StringType, nullable = true)()
/** Creates a new AttributeReference of type date */
- def date = AttributeReference(s, DateType, nullable = true)()
+ def date: AttributeReference = AttributeReference(s, DateType, nullable = true)()
/** Creates a new AttributeReference of type decimal */
- def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)()
+ def decimal: AttributeReference =
+ AttributeReference(s, DecimalType.Unlimited, nullable = true)()
/** Creates a new AttributeReference of type decimal */
- def decimal(precision: Int, scale: Int) =
+ def decimal(precision: Int, scale: Int): AttributeReference =
AttributeReference(s, DecimalType(precision, scale), nullable = true)()
/** Creates a new AttributeReference of type timestamp */
- def timestamp = AttributeReference(s, TimestampType, nullable = true)()
+ def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)()
/** Creates a new AttributeReference of type binary */
- def binary = AttributeReference(s, BinaryType, nullable = true)()
+ def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)()
/** Creates a new AttributeReference of type array */
- def array(dataType: DataType) = AttributeReference(s, ArrayType(dataType), nullable = true)()
+ def array(dataType: DataType): AttributeReference =
+ AttributeReference(s, ArrayType(dataType), nullable = true)()
/** Creates a new AttributeReference of type map */
def map(keyType: DataType, valueType: DataType): AttributeReference =
map(MapType(keyType, valueType))
- def map(mapType: MapType) = AttributeReference(s, mapType, nullable = true)()
+
+ def map(mapType: MapType): AttributeReference =
+ AttributeReference(s, mapType, nullable = true)()
/** Creates a new AttributeReference of type struct */
def struct(fields: StructField*): AttributeReference = struct(StructType(fields))
- def struct(structType: StructType) = AttributeReference(s, structType, nullable = true)()
+ def struct(structType: StructType): AttributeReference =
+ AttributeReference(s, structType, nullable = true)()
}
implicit class DslAttribute(a: AttributeReference) {
- def notNull = a.withNullability(false)
- def nullable = a.withNullability(true)
+ def notNull: AttributeReference = a.withNullability(false)
+ def nullable: AttributeReference = a.withNullability(true)
// Protobuf terminology
- def required = a.withNullability(false)
+ def required: AttributeReference = a.withNullability(false)
- def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable)
+ def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable)
}
}
@@ -241,23 +247,23 @@ package object dsl {
abstract class LogicalPlanFunctions {
def logicalPlan: LogicalPlan
- def select(exprs: NamedExpression*) = Project(exprs, logicalPlan)
+ def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan)
- def where(condition: Expression) = Filter(condition, logicalPlan)
+ def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)
- def limit(limitExpr: Expression) = Limit(limitExpr, logicalPlan)
+ def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan)
def join(
otherPlan: LogicalPlan,
joinType: JoinType = Inner,
- condition: Option[Expression] = None) =
+ condition: Option[Expression] = None): LogicalPlan =
Join(logicalPlan, otherPlan, joinType, condition)
- def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, true, logicalPlan)
+ def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan)
- def sortBy(sortExprs: SortOrder*) = Sort(sortExprs, false, logicalPlan)
+ def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan)
- def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*) = {
+ def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = {
val aliasedExprs = aggregateExprs.map {
case ne: NamedExpression => ne
case e => Alias(e, e.toString)()
@@ -265,41 +271,43 @@ package object dsl {
Aggregate(groupingExprs, aliasedExprs, logicalPlan)
}
- def subquery(alias: Symbol) = Subquery(alias.name, logicalPlan)
+ def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan)
- def unionAll(otherPlan: LogicalPlan) = Union(logicalPlan, otherPlan)
+ def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan)
- def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) =
+ def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan =
Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)
def sample(
fraction: Double,
withReplacement: Boolean = true,
- seed: Int = (math.random * 1000).toInt) =
+ seed: Int = (math.random * 1000).toInt): LogicalPlan =
Sample(fraction, withReplacement, seed, logicalPlan)
def generate(
generator: Generator,
join: Boolean = false,
outer: Boolean = false,
- alias: Option[String] = None) =
+ alias: Option[String] = None): LogicalPlan =
Generate(generator, join, outer, None, logicalPlan)
- def insertInto(tableName: String, overwrite: Boolean = false) =
+ def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
InsertIntoTable(
analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite)
- def analyze = analysis.SimpleAnalyzer(logicalPlan)
+ def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer(logicalPlan))
}
object plans { // scalastyle:ignore
implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) extends LogicalPlanFunctions {
- def writeToFile(path: String) = WriteToFile(path, logicalPlan)
+ def writeToFile(path: String): LogicalPlan = WriteToFile(path, logicalPlan)
}
}
case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) {
- def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args)
+ def call(args: Expression*): ScalaUdf = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args)
+ }
}
// scalastyle:off
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 82e760b6c6916..96a11e352ec50 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -23,7 +23,9 @@ package org.apache.spark.sql.catalyst.expressions
* of the name, or the expected nullability).
*/
object AttributeMap {
- def apply[A](kvs: Seq[(Attribute, A)]) = new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
+ def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
+ new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
+ }
}
class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index adaeab0b5c027..f9ae85a5cfc1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -19,27 +19,27 @@ package org.apache.spark.sql.catalyst.expressions
protected class AttributeEquals(val a: Attribute) {
- override def hashCode() = a match {
+ override def hashCode(): Int = a match {
case ar: AttributeReference => ar.exprId.hashCode()
case a => a.hashCode()
}
- override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match {
+ override def equals(other: Any): Boolean = (a, other.asInstanceOf[AttributeEquals].a) match {
case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId
case (a1, a2) => a1 == a2
}
}
object AttributeSet {
- def apply(a: Attribute) =
- new AttributeSet(Set(new AttributeEquals(a)))
+ def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a)))
/** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
- def apply(baseSet: Seq[Expression]) =
+ def apply(baseSet: Seq[Expression]): AttributeSet = {
new AttributeSet(
baseSet
.flatMap(_.references)
.map(new AttributeEquals(_)).toSet)
+ }
}
/**
@@ -57,7 +57,7 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
extends Traversable[Attribute] with Serializable {
/** Returns true if the members of this AttributeSet and other are the same. */
- override def equals(other: Any) = other match {
+ override def equals(other: Any): Boolean = other match {
case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains)
case _ => false
}
@@ -81,32 +81,34 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
* Returns true if the [[Attribute Attributes]] in this set are a subset of the Attributes in
* `other`.
*/
- def subsetOf(other: AttributeSet) = baseSet.subsetOf(other.baseSet)
+ def subsetOf(other: AttributeSet): Boolean = baseSet.subsetOf(other.baseSet)
/**
* Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found
* in `other`.
*/
- def --(other: Traversable[NamedExpression]) =
+ def --(other: Traversable[NamedExpression]): AttributeSet =
new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute)))
/**
* Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found
* in `other`.
*/
- def ++(other: AttributeSet) = new AttributeSet(baseSet ++ other.baseSet)
+ def ++(other: AttributeSet): AttributeSet = new AttributeSet(baseSet ++ other.baseSet)
/**
* Returns a new [[AttributeSet]] contain only the [[Attribute Attributes]] where `f` evaluates to
* true.
*/
- override def filter(f: Attribute => Boolean) = new AttributeSet(baseSet.filter(ae => f(ae.a)))
+ override def filter(f: Attribute => Boolean): AttributeSet =
+ new AttributeSet(baseSet.filter(ae => f(ae.a)))
/**
* Returns a new [[AttributeSet]] that only contains [[Attribute Attributes]] that are found in
* `this` and `other`.
*/
- def intersect(other: AttributeSet) = new AttributeSet(baseSet.intersect(other.baseSet))
+ def intersect(other: AttributeSet): AttributeSet =
+ new AttributeSet(baseSet.intersect(other.baseSet))
override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 76a9f08dea85f..2225621dbaabd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -32,7 +32,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
type EvaluatedType = Any
- override def toString = s"input[$ordinal]"
+ override def toString: String = s"input[$ordinal]"
override def eval(input: Row): Any = input(ordinal)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index b1bc858478ee1..9bde74ac22669 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -29,9 +29,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
override lazy val resolved = childrenResolved && resolve(child.dataType, dataType)
- override def foldable = child.foldable
+ override def foldable: Boolean = child.foldable
- override def nullable = forceNullable(child.dataType, dataType) || child.nullable
+ override def nullable: Boolean = forceNullable(child.dataType, dataType) || child.nullable
private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match {
case (StringType, _: NumericType) => true
@@ -103,7 +103,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
}
}
- override def toString = s"CAST($child, $dataType)"
+ override def toString: String = s"CAST($child, $dataType)"
type EvaluatedType = Any
@@ -430,14 +430,14 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
object Cast {
// `SimpleDateFormat` is not thread-safe.
private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] {
- override def initialValue() = {
+ override def initialValue(): SimpleDateFormat = {
new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
}
}
// `SimpleDateFormat` is not thread-safe.
private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] {
- override def initialValue() = {
+ override def initialValue(): SimpleDateFormat = {
new SimpleDateFormat("yyyy-MM-dd")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 6ad39b8372cfb..4e3bbc06a5b4c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -65,7 +65,7 @@ abstract class Expression extends TreeNode[Expression] {
* Returns true if all the children of this expression have been resolved to a specific schema
* and false if any still contains any unresolved placeholders.
*/
- def childrenResolved = !children.exists(!_.resolved)
+ def childrenResolved: Boolean = !children.exists(!_.resolved)
/**
* Returns a string representation of this expression that does not have developer centric
@@ -84,9 +84,9 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
def symbol: String
- override def foldable = left.foldable && right.foldable
+ override def foldable: Boolean = left.foldable && right.foldable
- override def toString = s"($left $symbol $right)"
+ override def toString: String = s"($left $symbol $right)"
}
abstract class LeafExpression extends Expression with trees.LeafNode[Expression] {
@@ -104,8 +104,8 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
case class GroupExpression(children: Seq[Expression]) extends Expression {
self: Product =>
type EvaluatedType = Seq[Any]
- override def eval(input: Row): EvaluatedType = ???
- override def nullable = false
- override def foldable = false
- override def dataType = ???
+ override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException
+ override def nullable: Boolean = false
+ override def foldable: Boolean = false
+ override def dataType: DataType = throw new UnsupportedOperationException
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index db5d897ee569f..c2866cd955409 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -40,7 +40,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
new GenericRow(outputArray)
}
- override def toString = s"Row => [${exprArray.mkString(",")}]"
+ override def toString: String = s"Row => [${exprArray.mkString(",")}]"
}
/**
@@ -107,12 +107,12 @@ class JoinedRow extends Row {
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- override def length = row1.length + row2.length
+ override def length: Int = row1.length + row2.length
- override def apply(i: Int) =
+ override def apply(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)
- override def isNullAt(i: Int) =
+ override def isNullAt(i: Int): Boolean =
if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
override def getInt(i: Int): Int =
@@ -142,7 +142,7 @@ class JoinedRow extends Row {
override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- override def copy() = {
+ override def copy(): Row = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
@@ -153,7 +153,7 @@ class JoinedRow extends Row {
new GenericRow(copiedValues)
}
- override def toString() = {
+ override def toString: String = {
// Make sure toString never throws NullPointerException.
if ((row1 eq null) && (row2 eq null)) {
"[ empty row ]"
@@ -207,12 +207,12 @@ class JoinedRow2 extends Row {
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- override def length = row1.length + row2.length
+ override def length: Int = row1.length + row2.length
- override def apply(i: Int) =
+ override def apply(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)
- override def isNullAt(i: Int) =
+ override def isNullAt(i: Int): Boolean =
if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
override def getInt(i: Int): Int =
@@ -242,7 +242,7 @@ class JoinedRow2 extends Row {
override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- override def copy() = {
+ override def copy(): Row = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
@@ -253,7 +253,7 @@ class JoinedRow2 extends Row {
new GenericRow(copiedValues)
}
- override def toString() = {
+ override def toString: String = {
// Make sure toString never throws NullPointerException.
if ((row1 eq null) && (row2 eq null)) {
"[ empty row ]"
@@ -301,12 +301,12 @@ class JoinedRow3 extends Row {
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- override def length = row1.length + row2.length
+ override def length: Int = row1.length + row2.length
- override def apply(i: Int) =
+ override def apply(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)
- override def isNullAt(i: Int) =
+ override def isNullAt(i: Int): Boolean =
if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
override def getInt(i: Int): Int =
@@ -336,7 +336,7 @@ class JoinedRow3 extends Row {
override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- override def copy() = {
+ override def copy(): Row = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
@@ -347,7 +347,7 @@ class JoinedRow3 extends Row {
new GenericRow(copiedValues)
}
- override def toString() = {
+ override def toString: String = {
// Make sure toString never throws NullPointerException.
if ((row1 eq null) && (row2 eq null)) {
"[ empty row ]"
@@ -395,12 +395,12 @@ class JoinedRow4 extends Row {
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- override def length = row1.length + row2.length
+ override def length: Int = row1.length + row2.length
- override def apply(i: Int) =
+ override def apply(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)
- override def isNullAt(i: Int) =
+ override def isNullAt(i: Int): Boolean =
if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
override def getInt(i: Int): Int =
@@ -430,7 +430,7 @@ class JoinedRow4 extends Row {
override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- override def copy() = {
+ override def copy(): Row = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
@@ -441,7 +441,7 @@ class JoinedRow4 extends Row {
new GenericRow(copiedValues)
}
- override def toString() = {
+ override def toString: String = {
// Make sure toString never throws NullPointerException.
if ((row1 eq null) && (row2 eq null)) {
"[ empty row ]"
@@ -489,12 +489,12 @@ class JoinedRow5 extends Row {
override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq
- override def length = row1.length + row2.length
+ override def length: Int = row1.length + row2.length
- override def apply(i: Int) =
+ override def apply(i: Int): Any =
if (i < row1.length) row1(i) else row2(i - row1.length)
- override def isNullAt(i: Int) =
+ override def isNullAt(i: Int): Boolean =
if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length)
override def getInt(i: Int): Int =
@@ -524,7 +524,7 @@ class JoinedRow5 extends Row {
override def getAs[T](i: Int): T =
if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length)
- override def copy() = {
+ override def copy(): Row = {
val totalSize = row1.length + row2.length
val copiedValues = new Array[Any](totalSize)
var i = 0
@@ -535,7 +535,7 @@ class JoinedRow5 extends Row {
new GenericRow(copiedValues)
}
- override def toString() = {
+ override def toString: String = {
// Make sure toString never throws NullPointerException.
if ((row1 eq null) && (row2 eq null)) {
"[ empty row ]"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
index b2c6d3029031d..f5fea3f015dc4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
@@ -18,16 +18,19 @@
package org.apache.spark.sql.catalyst.expressions
import java.util.Random
-import org.apache.spark.sql.types.DoubleType
+
+import org.apache.spark.sql.types.{DataType, DoubleType}
case object Rand extends LeafExpression {
- override def dataType = DoubleType
- override def nullable = false
+ override def dataType: DataType = DoubleType
+ override def nullable: Boolean = false
private[this] lazy val rand = new Random
- override def eval(input: Row = null) = rand.nextDouble().asInstanceOf[EvaluatedType]
+ override def eval(input: Row = null): EvaluatedType = {
+ rand.nextDouble().asInstanceOf[EvaluatedType]
+ }
- override def toString = "RAND()"
+ override def toString: String = "RAND()"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 8a36c6810790d..1fd5ce342b2ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -29,9 +29,9 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
type EvaluatedType = Any
- def nullable = true
+ override def nullable: Boolean = true
- override def toString = s"scalaUDF(${children.mkString(",")})"
+ override def toString: String = s"scalaUDF(${children.mkString(",")})"
// scalastyle:off
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index d00b2ac09745c..83074eb1e6310 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.types.DataType
abstract sealed class SortDirection
case object Ascending extends SortDirection
@@ -31,12 +32,12 @@ case object Descending extends SortDirection
case class SortOrder(child: Expression, direction: SortDirection) extends Expression
with trees.UnaryNode[Expression] {
- override def dataType = child.dataType
- override def nullable = child.nullable
+ override def dataType: DataType = child.dataType
+ override def nullable: Boolean = child.nullable
// SortOrder itself is never evaluated.
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
- override def toString = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
+ override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 21d714c9a8c3b..47b6f358ed1b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -62,126 +62,126 @@ abstract class MutableValue extends Serializable {
var isNull: Boolean = true
def boxed: Any
def update(v: Any)
- def copy(): this.type
+ def copy(): MutableValue
}
final class MutableInt extends MutableValue {
var value: Int = 0
- def boxed = if (isNull) null else value
- def update(v: Any) = value = {
+ override def boxed: Any = if (isNull) null else value
+ override def update(v: Any): Unit = {
isNull = false
- v.asInstanceOf[Int]
+ value = v.asInstanceOf[Int]
}
- def copy() = {
+ override def copy(): MutableInt = {
val newCopy = new MutableInt
newCopy.isNull = isNull
newCopy.value = value
- newCopy.asInstanceOf[this.type]
+ newCopy.asInstanceOf[MutableInt]
}
}
final class MutableFloat extends MutableValue {
var value: Float = 0
- def boxed = if (isNull) null else value
- def update(v: Any) = value = {
+ override def boxed: Any = if (isNull) null else value
+ override def update(v: Any): Unit = {
isNull = false
- v.asInstanceOf[Float]
+ value = v.asInstanceOf[Float]
}
- def copy() = {
+ override def copy(): MutableFloat = {
val newCopy = new MutableFloat
newCopy.isNull = isNull
newCopy.value = value
- newCopy.asInstanceOf[this.type]
+ newCopy.asInstanceOf[MutableFloat]
}
}
final class MutableBoolean extends MutableValue {
var value: Boolean = false
- def boxed = if (isNull) null else value
- def update(v: Any) = value = {
+ override def boxed: Any = if (isNull) null else value
+ override def update(v: Any): Unit = {
isNull = false
- v.asInstanceOf[Boolean]
+ value = v.asInstanceOf[Boolean]
}
- def copy() = {
+ override def copy(): MutableBoolean = {
val newCopy = new MutableBoolean
newCopy.isNull = isNull
newCopy.value = value
- newCopy.asInstanceOf[this.type]
+ newCopy.asInstanceOf[MutableBoolean]
}
}
final class MutableDouble extends MutableValue {
var value: Double = 0
- def boxed = if (isNull) null else value
- def update(v: Any) = value = {
+ override def boxed: Any = if (isNull) null else value
+ override def update(v: Any): Unit = {
isNull = false
- v.asInstanceOf[Double]
+ value = v.asInstanceOf[Double]
}
- def copy() = {
+ override def copy(): MutableDouble = {
val newCopy = new MutableDouble
newCopy.isNull = isNull
newCopy.value = value
- newCopy.asInstanceOf[this.type]
+ newCopy.asInstanceOf[MutableDouble]
}
}
final class MutableShort extends MutableValue {
var value: Short = 0
- def boxed = if (isNull) null else value
- def update(v: Any) = value = {
+ override def boxed: Any = if (isNull) null else value
+ override def update(v: Any): Unit = value = {
isNull = false
v.asInstanceOf[Short]
}
- def copy() = {
+ override def copy(): MutableShort = {
val newCopy = new MutableShort
newCopy.isNull = isNull
newCopy.value = value
- newCopy.asInstanceOf[this.type]
+ newCopy.asInstanceOf[MutableShort]
}
}
final class MutableLong extends MutableValue {
var value: Long = 0
- def boxed = if (isNull) null else value
- def update(v: Any) = value = {
+ override def boxed: Any = if (isNull) null else value
+ override def update(v: Any): Unit = value = {
isNull = false
v.asInstanceOf[Long]
}
- def copy() = {
+ override def copy(): MutableLong = {
val newCopy = new MutableLong
newCopy.isNull = isNull
newCopy.value = value
- newCopy.asInstanceOf[this.type]
+ newCopy.asInstanceOf[MutableLong]
}
}
final class MutableByte extends MutableValue {
var value: Byte = 0
- def boxed = if (isNull) null else value
- def update(v: Any) = value = {
+ override def boxed: Any = if (isNull) null else value
+ override def update(v: Any): Unit = value = {
isNull = false
v.asInstanceOf[Byte]
}
- def copy() = {
+ override def copy(): MutableByte = {
val newCopy = new MutableByte
newCopy.isNull = isNull
newCopy.value = value
- newCopy.asInstanceOf[this.type]
+ newCopy.asInstanceOf[MutableByte]
}
}
final class MutableAny extends MutableValue {
var value: Any = _
- def boxed = if (isNull) null else value
- def update(v: Any) = value = {
+ override def boxed: Any = if (isNull) null else value
+ override def update(v: Any): Unit = {
isNull = false
- v.asInstanceOf[Any]
+ value = v.asInstanceOf[Any]
}
- def copy() = {
+ override def copy(): MutableAny = {
val newCopy = new MutableAny
newCopy.isNull = isNull
newCopy.value = value
- newCopy.asInstanceOf[this.type]
+ newCopy.asInstanceOf[MutableAny]
}
}
@@ -234,9 +234,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
if (value == null) setNullAt(ordinal) else values(ordinal).update(value)
}
- override def setString(ordinal: Int, value: String) = update(ordinal, value)
+ override def setString(ordinal: Int, value: String): Unit = update(ordinal, value)
- override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
+ override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String]
override def setInt(ordinal: Int, value: Int): Unit = {
val currentValue = values(ordinal).asInstanceOf[MutableInt]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 5297d1e31246c..30da4faa3f1c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -79,27 +79,29 @@ abstract class AggregateFunction
/** Base should return the generic aggregate expression that this function is computing */
val base: AggregateExpression
- override def nullable = base.nullable
- override def dataType = base.dataType
+ override def nullable: Boolean = base.nullable
+ override def dataType: DataType = base.dataType
def update(input: Row): Unit
// Do we really need this?
- override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
+ override def newInstance(): AggregateFunction = {
+ makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
+ }
}
case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = true
- override def dataType = child.dataType
- override def toString = s"MIN($child)"
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"MIN($child)"
override def asPartial: SplitEvaluation = {
val partialMin = Alias(Min(child), "PartialMin")()
SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil)
}
- override def newInstance() = new MinFunction(child, this)
+ override def newInstance(): MinFunction = new MinFunction(child, this)
}
case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -121,16 +123,16 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr
case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = true
- override def dataType = child.dataType
- override def toString = s"MAX($child)"
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"MAX($child)"
override def asPartial: SplitEvaluation = {
val partialMax = Alias(Max(child), "PartialMax")()
SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil)
}
- override def newInstance() = new MaxFunction(child, this)
+ override def newInstance(): MaxFunction = new MaxFunction(child, this)
}
case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -152,29 +154,29 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = false
- override def dataType = LongType
- override def toString = s"COUNT($child)"
+ override def nullable: Boolean = false
+ override def dataType: LongType.type = LongType
+ override def toString: String = s"COUNT($child)"
override def asPartial: SplitEvaluation = {
val partialCount = Alias(Count(child), "PartialCount")()
SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil)
}
- override def newInstance() = new CountFunction(child, this)
+ override def newInstance(): CountFunction = new CountFunction(child, this)
}
case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate {
def this() = this(null)
- override def children = expressions
+ override def children: Seq[Expression] = expressions
- override def nullable = false
- override def dataType = LongType
- override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
- override def newInstance() = new CountDistinctFunction(expressions, this)
+ override def nullable: Boolean = false
+ override def dataType: DataType = LongType
+ override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})"
+ override def newInstance(): CountDistinctFunction = new CountDistinctFunction(expressions, this)
- override def asPartial = {
+ override def asPartial: SplitEvaluation = {
val partialSet = Alias(CollectHashSet(expressions), "partialSets")()
SplitEvaluation(
CombineSetsAndCount(partialSet.toAttribute),
@@ -185,11 +187,11 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate
case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression {
def this() = this(null)
- override def children = expressions
- override def nullable = false
- override def dataType = ArrayType(expressions.head.dataType)
- override def toString = s"AddToHashSet(${expressions.mkString(",")})"
- override def newInstance() = new CollectHashSetFunction(expressions, this)
+ override def children: Seq[Expression] = expressions
+ override def nullable: Boolean = false
+ override def dataType: ArrayType = ArrayType(expressions.head.dataType)
+ override def toString: String = s"AddToHashSet(${expressions.mkString(",")})"
+ override def newInstance(): CollectHashSetFunction = new CollectHashSetFunction(expressions, this)
}
case class CollectHashSetFunction(
@@ -219,11 +221,13 @@ case class CollectHashSetFunction(
case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression {
def this() = this(null)
- override def children = inputSet :: Nil
- override def nullable = false
- override def dataType = LongType
- override def toString = s"CombineAndCount($inputSet)"
- override def newInstance() = new CombineSetsAndCountFunction(inputSet, this)
+ override def children: Seq[Expression] = inputSet :: Nil
+ override def nullable: Boolean = false
+ override def dataType: DataType = LongType
+ override def toString: String = s"CombineAndCount($inputSet)"
+ override def newInstance(): CombineSetsAndCountFunction = {
+ new CombineSetsAndCountFunction(inputSet, this)
+ }
}
case class CombineSetsAndCountFunction(
@@ -249,27 +253,31 @@ case class CombineSetsAndCountFunction(
case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
- override def nullable = false
- override def dataType = child.dataType
- override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
- override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this, relativeSD)
+ override def nullable: Boolean = false
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
+ override def newInstance(): ApproxCountDistinctPartitionFunction = {
+ new ApproxCountDistinctPartitionFunction(child, this, relativeSD)
+ }
}
case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
- override def nullable = false
- override def dataType = LongType
- override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
- override def newInstance() = new ApproxCountDistinctMergeFunction(child, this, relativeSD)
+ override def nullable: Boolean = false
+ override def dataType: LongType.type = LongType
+ override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
+ override def newInstance(): ApproxCountDistinctMergeFunction = {
+ new ApproxCountDistinctMergeFunction(child, this, relativeSD)
+ }
}
case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = false
- override def dataType = LongType
- override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
+ override def nullable: Boolean = false
+ override def dataType: LongType.type = LongType
+ override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
override def asPartial: SplitEvaluation = {
val partialCount =
@@ -280,14 +288,14 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
partialCount :: Nil)
}
- override def newInstance() = new CountDistinctFunction(child :: Nil, this)
+ override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this)
}
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = true
+ override def nullable: Boolean = true
- override def dataType = child.dataType match {
+ override def dataType: DataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive
case DecimalType.Unlimited =>
@@ -296,7 +304,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
DoubleType
}
- override def toString = s"AVG($child)"
+ override def toString: String = s"AVG($child)"
override def asPartial: SplitEvaluation = {
child.dataType match {
@@ -323,14 +331,14 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
}
}
- override def newInstance() = new AverageFunction(child, this)
+ override def newInstance(): AverageFunction = new AverageFunction(child, this)
}
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = true
+ override def nullable: Boolean = true
- override def dataType = child.dataType match {
+ override def dataType: DataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
case DecimalType.Unlimited =>
@@ -339,7 +347,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
child.dataType
}
- override def toString = s"SUM($child)"
+ override def toString: String = s"SUM($child)"
override def asPartial: SplitEvaluation = {
child.dataType match {
@@ -357,7 +365,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
}
}
- override def newInstance() = new SumFunction(child, this)
+ override def newInstance(): SumFunction = new SumFunction(child, this)
}
/**
@@ -377,19 +385,19 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
case class CombineSum(child: Expression) extends AggregateExpression {
def this() = this(null)
- override def children = child :: Nil
- override def nullable = true
- override def dataType = child.dataType
- override def toString = s"CombineSum($child)"
- override def newInstance() = new CombineSumFunction(child, this)
+ override def children: Seq[Expression] = child :: Nil
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"CombineSum($child)"
+ override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
}
case class SumDistinct(child: Expression)
extends PartialAggregate with trees.UnaryNode[Expression] {
def this() = this(null)
- override def nullable = true
- override def dataType = child.dataType match {
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
case DecimalType.Unlimited =>
@@ -397,10 +405,10 @@ case class SumDistinct(child: Expression)
case _ =>
child.dataType
}
- override def toString = s"SUM(DISTINCT ${child})"
- override def newInstance() = new SumDistinctFunction(child, this)
+ override def toString: String = s"SUM(DISTINCT $child)"
+ override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this)
- override def asPartial = {
+ override def asPartial: SplitEvaluation = {
val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")()
SplitEvaluation(
CombineSetsAndSum(partialSet.toAttribute, this),
@@ -411,11 +419,13 @@ case class SumDistinct(child: Expression)
case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression {
def this() = this(null, null)
- override def children = inputSet :: Nil
- override def nullable = true
- override def dataType = base.dataType
- override def toString = s"CombineAndSum($inputSet)"
- override def newInstance() = new CombineSetsAndSumFunction(inputSet, this)
+ override def children: Seq[Expression] = inputSet :: Nil
+ override def nullable: Boolean = true
+ override def dataType: DataType = base.dataType
+ override def toString: String = s"CombineAndSum($inputSet)"
+ override def newInstance(): CombineSetsAndSumFunction = {
+ new CombineSetsAndSumFunction(inputSet, this)
+ }
}
case class CombineSetsAndSumFunction(
@@ -449,9 +459,9 @@ case class CombineSetsAndSumFunction(
}
case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def nullable = true
- override def dataType = child.dataType
- override def toString = s"FIRST($child)"
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"FIRST($child)"
override def asPartial: SplitEvaluation = {
val partialFirst = Alias(First(child), "PartialFirst")()
@@ -459,14 +469,14 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod
First(partialFirst.toAttribute),
partialFirst :: Nil)
}
- override def newInstance() = new FirstFunction(child, this)
+ override def newInstance(): FirstFunction = new FirstFunction(child, this)
}
case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- override def references = child.references
- override def nullable = true
- override def dataType = child.dataType
- override def toString = s"LAST($child)"
+ override def references: AttributeSet = child.references
+ override def nullable: Boolean = true
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"LAST($child)"
override def asPartial: SplitEvaluation = {
val partialLast = Alias(Last(child), "PartialLast")()
@@ -474,7 +484,7 @@ case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode
Last(partialLast.toAttribute),
partialLast :: Nil)
}
- override def newInstance() = new LastFunction(child, this)
+ override def newInstance(): LastFunction = new LastFunction(child, this)
}
case class AverageFunction(expr: Expression, base: AggregateExpression)
@@ -713,6 +723,7 @@ case class LastFunction(expr: Expression, base: AggregateExpression) extends Agg
result = input
}
- override def eval(input: Row): Any = if (result != null) expr.eval(result.asInstanceOf[Row])
- else null
+ override def eval(input: Row): Any = {
+ if (result != null) expr.eval(result.asInstanceOf[Row]) else null
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 00b0d3c683fe2..1f6526ef66c56 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -24,10 +24,10 @@ import org.apache.spark.sql.types._
case class UnaryMinus(child: Expression) extends UnaryExpression {
type EvaluatedType = Any
- def dataType = child.dataType
- override def foldable = child.foldable
- def nullable = child.nullable
- override def toString = s"-$child"
+ override def dataType: DataType = child.dataType
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = child.nullable
+ override def toString: String = s"-$child"
lazy val numeric = dataType match {
case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
@@ -47,10 +47,10 @@ case class UnaryMinus(child: Expression) extends UnaryExpression {
case class Sqrt(child: Expression) extends UnaryExpression {
type EvaluatedType = Any
- def dataType = DoubleType
- override def foldable = child.foldable
- def nullable = true
- override def toString = s"SQRT($child)"
+ override def dataType: DataType = DoubleType
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = true
+ override def toString: String = s"SQRT($child)"
lazy val numeric = child.dataType match {
case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
@@ -74,14 +74,14 @@ abstract class BinaryArithmetic extends BinaryExpression {
type EvaluatedType = Any
- def nullable = left.nullable || right.nullable
+ def nullable: Boolean = left.nullable || right.nullable
override lazy val resolved =
left.resolved && right.resolved &&
left.dataType == right.dataType &&
!DecimalType.isFixed(left.dataType)
- def dataType = {
+ def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this,
s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
@@ -108,7 +108,7 @@ abstract class BinaryArithmetic extends BinaryExpression {
}
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
- def symbol = "+"
+ override def symbol: String = "+"
lazy val numeric = dataType match {
case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
@@ -131,7 +131,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
}
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
- def symbol = "-"
+ override def symbol: String = "-"
lazy val numeric = dataType match {
case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
@@ -154,7 +154,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
}
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
- def symbol = "*"
+ override def symbol: String = "*"
lazy val numeric = dataType match {
case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
@@ -177,9 +177,9 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
}
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
- def symbol = "/"
+ override def symbol: String = "/"
- override def nullable = true
+ override def nullable: Boolean = true
lazy val div: (Any, Any) => Any = dataType match {
case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
@@ -203,9 +203,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
- def symbol = "%"
+ override def symbol: String = "%"
- override def nullable = true
+ override def nullable: Boolean = true
lazy val integral = dataType match {
case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
@@ -232,7 +232,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
* A function that calculates bitwise and(&) of two numbers.
*/
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
- def symbol = "&"
+ override def symbol: String = "&"
lazy val and: (Any, Any) => Any = dataType match {
case ByteType =>
@@ -253,7 +253,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
* A function that calculates bitwise or(|) of two numbers.
*/
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
- def symbol = "|"
+ override def symbol: String = "|"
lazy val or: (Any, Any) => Any = dataType match {
case ByteType =>
@@ -274,7 +274,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
* A function that calculates bitwise xor(^) of two numbers.
*/
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
- def symbol = "^"
+ override def symbol: String = "^"
lazy val xor: (Any, Any) => Any = dataType match {
case ByteType =>
@@ -297,10 +297,10 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
case class BitwiseNot(child: Expression) extends UnaryExpression {
type EvaluatedType = Any
- def dataType = child.dataType
- override def foldable = child.foldable
- def nullable = child.nullable
- override def toString = s"~$child"
+ override def dataType: DataType = child.dataType
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = child.nullable
+ override def toString: String = s"~$child"
lazy val not: (Any) => Any = dataType match {
case ByteType =>
@@ -327,17 +327,17 @@ case class BitwiseNot(child: Expression) extends UnaryExpression {
case class MaxOf(left: Expression, right: Expression) extends Expression {
type EvaluatedType = Any
- override def foldable = left.foldable && right.foldable
+ override def foldable: Boolean = left.foldable && right.foldable
- override def nullable = left.nullable && right.nullable
+ override def nullable: Boolean = left.nullable && right.nullable
- override def children = left :: right :: Nil
+ override def children: Seq[Expression] = left :: right :: Nil
override lazy val resolved =
left.resolved && right.resolved &&
left.dataType == right.dataType
- override def dataType = {
+ override def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this,
s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
@@ -366,7 +366,7 @@ case class MaxOf(left: Expression, right: Expression) extends Expression {
}
}
- override def toString = s"MaxOf($left, $right)"
+ override def toString: String = s"MaxOf($left, $right)"
}
/**
@@ -375,10 +375,10 @@ case class MaxOf(left: Expression, right: Expression) extends Expression {
case class Abs(child: Expression) extends UnaryExpression {
type EvaluatedType = Any
- def dataType = child.dataType
- override def foldable = child.foldable
- def nullable = child.nullable
- override def toString = s"Abs($child)"
+ override def dataType: DataType = child.dataType
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = child.nullable
+ override def toString: String = s"Abs($child)"
lazy val numeric = dataType match {
case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index e48b8cde20eda..d1abf3c0b64a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -91,7 +91,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val startTime = System.nanoTime()
val result = create(in)
val endTime = System.nanoTime()
- def timeMs = (endTime - startTime).toDouble / 1000000
+ def timeMs: Double = (endTime - startTime).toDouble / 1000000
logInfo(s"Code generated expression $in in $timeMs ms")
result
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index 68051a2a2007e..3fd78db297462 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -27,12 +27,12 @@ import org.apache.spark.sql.types._
case class GetItem(child: Expression, ordinal: Expression) extends Expression {
type EvaluatedType = Any
- val children = child :: ordinal :: Nil
+ val children: Seq[Expression] = child :: ordinal :: Nil
/** `Null` is returned for invalid ordinals. */
- override def nullable = true
- override def foldable = child.foldable && ordinal.foldable
+ override def nullable: Boolean = true
+ override def foldable: Boolean = child.foldable && ordinal.foldable
- def dataType = child.dataType match {
+ override def dataType: DataType = child.dataType match {
case ArrayType(dt, _) => dt
case MapType(_, vt, _) => vt
}
@@ -40,7 +40,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
childrenResolved &&
(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
- override def toString = s"$child[$ordinal]"
+ override def toString: String = s"$child[$ordinal]"
override def eval(input: Row): Any = {
val value = child.eval(input)
@@ -75,8 +75,8 @@ trait GetField extends UnaryExpression {
self: Product =>
type EvaluatedType = Any
- override def foldable = child.foldable
- override def toString = s"$child.${field.name}"
+ override def foldable: Boolean = child.foldable
+ override def toString: String = s"$child.${field.name}"
def field: StructField
}
@@ -86,8 +86,8 @@ trait GetField extends UnaryExpression {
*/
case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField {
- def dataType = field.dataType
- override def nullable = child.nullable || field.nullable
+ override def dataType: DataType = field.dataType
+ override def nullable: Boolean = child.nullable || field.nullable
override def eval(input: Row): Any = {
val baseValue = child.eval(input).asInstanceOf[Row]
@@ -101,8 +101,8 @@ case class StructGetField(child: Expression, field: StructField, ordinal: Int) e
case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean)
extends GetField {
- def dataType = ArrayType(field.dataType, containsNull)
- override def nullable = child.nullable
+ override def dataType: DataType = ArrayType(field.dataType, containsNull)
+ override def nullable: Boolean = child.nullable
override def eval(input: Row): Any = {
val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
@@ -120,7 +120,7 @@ case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, co
case class CreateArray(children: Seq[Expression]) extends Expression {
override type EvaluatedType = Any
- override def foldable = !children.exists(!_.foldable)
+ override def foldable: Boolean = !children.exists(!_.foldable)
lazy val childTypes = children.map(_.dataType).distinct
@@ -140,5 +140,5 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
children.map(_.eval(input))
}
- override def toString = s"Array(${children.mkString(",")})"
+ override def toString: String = s"Array(${children.mkString(",")})"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
index 83d8c1d42bca4..adb94df7d1c7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -24,9 +24,9 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
override type EvaluatedType = Any
override def dataType: DataType = LongType
- override def foldable = child.foldable
- def nullable = child.nullable
- override def toString = s"UnscaledValue($child)"
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = child.nullable
+ override def toString: String = s"UnscaledValue($child)"
override def eval(input: Row): Any = {
val childResult = child.eval(input)
@@ -43,9 +43,9 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
override type EvaluatedType = Decimal
override def dataType: DataType = DecimalType(precision, scale)
- override def foldable = child.foldable
- def nullable = child.nullable
- override def toString = s"MakeDecimal($child,$precision,$scale)"
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = child.nullable
+ override def toString: String = s"MakeDecimal($child,$precision,$scale)"
override def eval(input: Row): Decimal = {
val childResult = child.eval(input)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 0983d274def3f..860b72fad38b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -45,7 +45,7 @@ abstract class Generator extends Expression {
override lazy val dataType =
ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
- override def nullable = false
+ override def nullable: Boolean = false
/**
* Should be overridden by specific generators. Called only once for each instance to ensure
@@ -89,7 +89,7 @@ case class UserDefinedGenerator(
function(inputRow(input))
}
- override def toString = s"UserDefinedGenerator(${children.mkString(",")})"
+ override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})"
}
/**
@@ -130,5 +130,5 @@ case class Explode(attributeNames: Seq[String], child: Expression)
}
}
- override def toString() = s"explode($child)"
+ override def toString: String = s"explode($child)"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 9ff66563c8164..19f3fc9c2291a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -64,14 +64,13 @@ object IntegerLiteral {
case class Literal(value: Any, dataType: DataType) extends LeafExpression {
- override def foldable = true
- def nullable = value == null
+ override def foldable: Boolean = true
+ override def nullable: Boolean = value == null
-
- override def toString = if (value != null) value.toString else "null"
+ override def toString: String = if (value != null) value.toString else "null"
type EvaluatedType = Any
- override def eval(input: Row):Any = value
+ override def eval(input: Row): Any = value
}
// TODO: Specialize
@@ -79,9 +78,9 @@ case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean
extends LeafExpression {
type EvaluatedType = Any
- def update(expression: Expression, input: Row) = {
+ def update(expression: Expression, input: Row): Unit = {
value = expression.eval(input)
}
- override def eval(input: Row) = value
+ override def eval(input: Row): Any = value
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 17f7f9fe51376..bcbcbeb31c7b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.trees.LeafNode
import org.apache.spark.sql.types._
object NamedExpression {
private val curId = new java.util.concurrent.atomic.AtomicLong()
- def newExprId = ExprId(curId.getAndIncrement())
+ def newExprId: ExprId = ExprId(curId.getAndIncrement())
def unapply(expr: NamedExpression): Option[(String, DataType)] = Some(expr.name, expr.dataType)
}
@@ -41,6 +42,13 @@ abstract class NamedExpression extends Expression {
def name: String
def exprId: ExprId
+ /**
+ * Returns a dot separated fully qualified name for this attribute. Given that there can be
+ * multiple qualifiers, it is possible that there are other possible way to refer to this
+ * attribute.
+ */
+ def qualifiedName: String = (qualifiers.headOption.toSeq :+ name).mkString(".")
+
/**
* All possible qualifiers for the expression.
*
@@ -72,13 +80,13 @@ abstract class NamedExpression extends Expression {
abstract class Attribute extends NamedExpression {
self: Product =>
- override def references = AttributeSet(this)
+ override def references: AttributeSet = AttributeSet(this)
def withNullability(newNullability: Boolean): Attribute
def withQualifiers(newQualifiers: Seq[String]): Attribute
def withName(newName: String): Attribute
- def toAttribute = this
+ def toAttribute: Attribute = this
def newInstance(): Attribute
}
@@ -95,25 +103,30 @@ abstract class Attribute extends NamedExpression {
* @param name the name to be associated with the result of computing [[child]].
* @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this
* alias. Auto-assigned if left blank.
+ * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's.
*/
-case class Alias(child: Expression, name: String)
- (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
+case class Alias(child: Expression, name: String)(
+ val exprId: ExprId = NamedExpression.newExprId,
+ val qualifiers: Seq[String] = Nil,
+ val explicitMetadata: Option[Metadata] = None)
extends NamedExpression with trees.UnaryNode[Expression] {
override type EvaluatedType = Any
- override def eval(input: Row) = child.eval(input)
+ override def eval(input: Row): Any = child.eval(input)
- override def dataType = child.dataType
- override def nullable = child.nullable
+ override def dataType: DataType = child.dataType
+ override def nullable: Boolean = child.nullable
override def metadata: Metadata = {
- child match {
- case named: NamedExpression => named.metadata
- case _ => Metadata.empty
+ explicitMetadata.getOrElse {
+ child match {
+ case named: NamedExpression => named.metadata
+ case _ => Metadata.empty
+ }
}
}
- override def toAttribute = {
+ override def toAttribute: Attribute = {
if (resolved) {
AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers)
} else {
@@ -123,11 +136,14 @@ case class Alias(child: Expression, name: String)
override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix"
- override protected final def otherCopyArgs = exprId :: qualifiers :: Nil
+ override protected final def otherCopyArgs: Seq[AnyRef] = {
+ exprId :: qualifiers :: explicitMetadata :: Nil
+ }
override def equals(other: Any): Boolean = other match {
case a: Alias =>
- name == a.name && exprId == a.exprId && child == a.child && qualifiers == a.qualifiers
+ name == a.name && exprId == a.exprId && child == a.child && qualifiers == a.qualifiers &&
+ explicitMetadata == a.explicitMetadata
case _ => false
}
}
@@ -153,7 +169,7 @@ case class AttributeReference(
val exprId: ExprId = NamedExpression.newExprId,
val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] {
- override def equals(other: Any) = other match {
+ override def equals(other: Any): Boolean = other match {
case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType
case _ => false
}
@@ -167,7 +183,7 @@ case class AttributeReference(
h
}
- override def newInstance() =
+ override def newInstance(): AttributeReference =
AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers)
/**
@@ -192,7 +208,7 @@ case class AttributeReference(
/**
* Returns a copy of this [[AttributeReference]] with new qualifiers.
*/
- override def withQualifiers(newQualifiers: Seq[String]) = {
+ override def withQualifiers(newQualifiers: Seq[String]): AttributeReference = {
if (newQualifiers.toSet == qualifiers.toSet) {
this
} else {
@@ -214,20 +230,22 @@ case class AttributeReference(
case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] {
type EvaluatedType = Any
- override def toString = name
-
- override def withNullability(newNullability: Boolean): Attribute = ???
- override def newInstance(): Attribute = ???
- override def withQualifiers(newQualifiers: Seq[String]): Attribute = ???
- override def withName(newName: String): Attribute = ???
- override def qualifiers: Seq[String] = ???
- override def exprId: ExprId = ???
- override def eval(input: Row): EvaluatedType = ???
- override def nullable: Boolean = ???
+ override def toString: String = name
+
+ override def withNullability(newNullability: Boolean): Attribute =
+ throw new UnsupportedOperationException
+ override def newInstance(): Attribute = throw new UnsupportedOperationException
+ override def withQualifiers(newQualifiers: Seq[String]): Attribute =
+ throw new UnsupportedOperationException
+ override def withName(newName: String): Attribute = throw new UnsupportedOperationException
+ override def qualifiers: Seq[String] = throw new UnsupportedOperationException
+ override def exprId: ExprId = throw new UnsupportedOperationException
+ override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException
+ override def nullable: Boolean = throw new UnsupportedOperationException
override def dataType: DataType = NullType
}
object VirtualColumn {
- val groupingIdName = "grouping__id"
- def newGroupingId = AttributeReference(groupingIdName, IntegerType, false)()
+ val groupingIdName: String = "grouping__id"
+ def newGroupingId: AttributeReference = AttributeReference(groupingIdName, IntegerType, false)()
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index 08b982bc671e7..d1f3d4f4ee9ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -19,22 +19,23 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.types.DataType
case class Coalesce(children: Seq[Expression]) extends Expression {
type EvaluatedType = Any
/** Coalesce is nullable if all of its children are nullable, or if it has no children. */
- def nullable = !children.exists(!_.nullable)
+ override def nullable: Boolean = !children.exists(!_.nullable)
// Coalesce is foldable if all children are foldable.
- override def foldable = !children.exists(!_.foldable)
+ override def foldable: Boolean = !children.exists(!_.foldable)
// Only resolved if all the children are of the same type.
override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1)
- override def toString = s"Coalesce(${children.mkString(",")})"
+ override def toString: String = s"Coalesce(${children.mkString(",")})"
- def dataType = if (resolved) {
+ def dataType: DataType = if (resolved) {
children.head.dataType
} else {
val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ")
@@ -54,20 +55,20 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
- override def foldable = child.foldable
- def nullable = false
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = false
override def eval(input: Row): Any = {
child.eval(input) == null
}
- override def toString = s"IS NULL $child"
+ override def toString: String = s"IS NULL $child"
}
case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
- override def foldable = child.foldable
- def nullable = false
- override def toString = s"IS NOT NULL $child"
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = false
+ override def toString: String = s"IS NOT NULL $child"
override def eval(input: Row): Any = {
child.eval(input) != null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 0024ef92c0452..7e47cb3fffe12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types.{BinaryType, BooleanType, NativeType}
+import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, NativeType}
object InterpretedPredicate {
def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
@@ -34,7 +34,7 @@ object InterpretedPredicate {
trait Predicate extends Expression {
self: Product =>
- def dataType = BooleanType
+ override def dataType: DataType = BooleanType
type EvaluatedType = Any
}
@@ -72,13 +72,13 @@ trait PredicateHelper {
abstract class BinaryPredicate extends BinaryExpression with Predicate {
self: Product =>
- def nullable = left.nullable || right.nullable
+ override def nullable: Boolean = left.nullable || right.nullable
}
case class Not(child: Expression) extends UnaryExpression with Predicate {
- override def foldable = child.foldable
- def nullable = child.nullable
- override def toString = s"NOT $child"
+ override def foldable: Boolean = child.foldable
+ override def nullable: Boolean = child.nullable
+ override def toString: String = s"NOT $child"
override def eval(input: Row): Any = {
child.eval(input) match {
@@ -92,10 +92,10 @@ case class Not(child: Expression) extends UnaryExpression with Predicate {
* Evaluates to `true` if `list` contains `value`.
*/
case class In(value: Expression, list: Seq[Expression]) extends Predicate {
- def children = value +: list
+ override def children: Seq[Expression] = value +: list
- def nullable = true // TODO: Figure out correct nullability semantics of IN.
- override def toString = s"$value IN ${list.mkString("(", ",", ")")}"
+ override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
+ override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"
override def eval(input: Row): Any = {
val evaluatedValue = value.eval(input)
@@ -110,10 +110,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
case class InSet(value: Expression, hset: Set[Any])
extends Predicate {
- def children = value :: Nil
+ override def children: Seq[Expression] = value :: Nil
- def nullable = true // TODO: Figure out correct nullability semantics of IN.
- override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}"
+ override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
+ override def toString: String = s"$value INSET ${hset.mkString("(", ",", ")")}"
override def eval(input: Row): Any = {
hset.contains(value.eval(input))
@@ -121,7 +121,7 @@ case class InSet(value: Expression, hset: Set[Any])
}
case class And(left: Expression, right: Expression) extends BinaryPredicate {
- def symbol = "&&"
+ override def symbol: String = "&&"
override def eval(input: Row): Any = {
val l = left.eval(input)
@@ -143,7 +143,7 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate {
}
case class Or(left: Expression, right: Expression) extends BinaryPredicate {
- def symbol = "||"
+ override def symbol: String = "||"
override def eval(input: Row): Any = {
val l = left.eval(input)
@@ -169,7 +169,8 @@ abstract class BinaryComparison extends BinaryPredicate {
}
case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
- def symbol = "="
+ override def symbol: String = "="
+
override def eval(input: Row): Any = {
val l = left.eval(input)
if (l == null) {
@@ -185,8 +186,10 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
}
case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {
- def symbol = "<=>"
- override def nullable = false
+ override def symbol: String = "<=>"
+
+ override def nullable: Boolean = false
+
override def eval(input: Row): Any = {
val l = left.eval(input)
val r = right.eval(input)
@@ -201,9 +204,9 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
}
case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
- def symbol = "<"
+ override def symbol: String = "<"
- lazy val ordering = {
+ lazy val ordering: Ordering[Any] = {
if (left.dataType != right.dataType) {
throw new TreeNodeException(this,
s"Types do not match ${left.dataType} != ${right.dataType}")
@@ -216,7 +219,7 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
override def eval(input: Row): Any = {
val evalE1 = left.eval(input)
- if(evalE1 == null) {
+ if (evalE1 == null) {
null
} else {
val evalE2 = right.eval(input)
@@ -230,9 +233,9 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
}
case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
- def symbol = "<="
+ override def symbol: String = "<="
- lazy val ordering = {
+ lazy val ordering: Ordering[Any] = {
if (left.dataType != right.dataType) {
throw new TreeNodeException(this,
s"Types do not match ${left.dataType} != ${right.dataType}")
@@ -245,7 +248,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
override def eval(input: Row): Any = {
val evalE1 = left.eval(input)
- if(evalE1 == null) {
+ if (evalE1 == null) {
null
} else {
val evalE2 = right.eval(input)
@@ -259,9 +262,9 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
}
case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
- def symbol = ">"
+ override def symbol: String = ">"
- lazy val ordering = {
+ lazy val ordering: Ordering[Any] = {
if (left.dataType != right.dataType) {
throw new TreeNodeException(this,
s"Types do not match ${left.dataType} != ${right.dataType}")
@@ -288,9 +291,9 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
}
case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
- def symbol = ">="
+ override def symbol: String = ">="
- lazy val ordering = {
+ lazy val ordering: Ordering[Any] = {
if (left.dataType != right.dataType) {
throw new TreeNodeException(this,
s"Types do not match ${left.dataType} != ${right.dataType}")
@@ -303,7 +306,7 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar
override def eval(input: Row): Any = {
val evalE1 = left.eval(input)
- if(evalE1 == null) {
+ if (evalE1 == null) {
null
} else {
val evalE2 = right.eval(input)
@@ -317,13 +320,13 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar
}
case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
- extends Expression {
+ extends Expression {
- def children = predicate :: trueValue :: falseValue :: Nil
- override def nullable = trueValue.nullable || falseValue.nullable
+ override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil
+ override def nullable: Boolean = trueValue.nullable || falseValue.nullable
override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType
- def dataType = {
+ override def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(
this,
@@ -342,7 +345,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
}
}
- override def toString = s"if ($predicate) $trueValue else $falseValue"
+ override def toString: String = s"if ($predicate) $trueValue else $falseValue"
}
// scalastyle:off
@@ -362,9 +365,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
// scalastyle:on
case class CaseWhen(branches: Seq[Expression]) extends Expression {
type EvaluatedType = Any
- def children = branches
- def dataType = {
+ override def children: Seq[Expression] = branches
+
+ override def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
}
@@ -379,12 +383,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
@transient private[this] lazy val elseValue =
if (branches.length % 2 == 0) None else Option(branches.last)
- override def nullable = {
+ override def nullable: Boolean = {
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
}
- override lazy val resolved = {
+ override lazy val resolved: Boolean = {
if (!childrenResolved) {
false
} else {
@@ -415,7 +419,7 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
res
}
- override def toString = {
+ override def toString: String = {
"CASE" + branches.sliding(2, 2).map {
case Seq(cond, value) => s" WHEN $cond THEN $value"
case Seq(elseValue) => s" ELSE $elseValue"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index f03d6f71a9fae..8bba26bc4cf7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -44,8 +44,8 @@ trait MutableRow extends Row {
*/
object EmptyRow extends Row {
override def apply(i: Int): Any = throw new UnsupportedOperationException
- override def toSeq = Seq.empty
- override def length = 0
+ override def toSeq: Seq[Any] = Seq.empty
+ override def length: Int = 0
override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException
override def getInt(i: Int): Int = throw new UnsupportedOperationException
override def getLong(i: Int): Long = throw new UnsupportedOperationException
@@ -56,7 +56,7 @@ object EmptyRow extends Row {
override def getByte(i: Int): Byte = throw new UnsupportedOperationException
override def getString(i: Int): String = throw new UnsupportedOperationException
override def getAs[T](i: Int): T = throw new UnsupportedOperationException
- def copy() = this
+ override def copy(): Row = this
}
/**
@@ -70,13 +70,13 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
def this(size: Int) = this(new Array[Any](size))
- override def toSeq = values.toSeq
+ override def toSeq: Seq[Any] = values.toSeq
- override def length = values.length
+ override def length: Int = values.length
- override def apply(i: Int) = values(i)
+ override def apply(i: Int): Any = values(i)
- override def isNullAt(i: Int) = values(i) == null
+ override def isNullAt(i: Int): Boolean = values(i) == null
override def getInt(i: Int): Int = {
if (values(i) == null) sys.error("Failed to check null bit for primitive int value.")
@@ -167,7 +167,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
case _ => false
}
- def copy() = this
+ override def copy(): Row = this
}
class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
@@ -194,7 +194,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value }
- override def copy() = new GenericRow(values.clone())
+ override def copy(): Row = new GenericRow(values.clone())
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index 3a5bdca1f07c3..35faa00782e80 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -26,17 +26,17 @@ import org.apache.spark.util.collection.OpenHashSet
case class NewSet(elementType: DataType) extends LeafExpression {
type EvaluatedType = Any
- def nullable = false
+ override def nullable: Boolean = false
// We are currently only using these Expressions internally for aggregation. However, if we ever
// expose these to users we'll want to create a proper type instead of hijacking ArrayType.
- def dataType = ArrayType(elementType)
+ override def dataType: DataType = ArrayType(elementType)
- def eval(input: Row): Any = {
+ override def eval(input: Row): Any = {
new OpenHashSet[Any]()
}
- override def toString = s"new Set($dataType)"
+ override def toString: String = s"new Set($dataType)"
}
/**
@@ -46,12 +46,13 @@ case class NewSet(elementType: DataType) extends LeafExpression {
case class AddItemToSet(item: Expression, set: Expression) extends Expression {
type EvaluatedType = Any
- def children = item :: set :: Nil
+ override def children: Seq[Expression] = item :: set :: Nil
- def nullable = set.nullable
+ override def nullable: Boolean = set.nullable
- def dataType = set.dataType
- def eval(input: Row): Any = {
+ override def dataType: DataType = set.dataType
+
+ override def eval(input: Row): Any = {
val itemEval = item.eval(input)
val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]
@@ -67,7 +68,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
}
}
- override def toString = s"$set += $item"
+ override def toString: String = s"$set += $item"
}
/**
@@ -77,13 +78,13 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
type EvaluatedType = Any
- def nullable = left.nullable || right.nullable
+ override def nullable: Boolean = left.nullable || right.nullable
- def dataType = left.dataType
+ override def dataType: DataType = left.dataType
- def symbol = "++="
+ override def symbol: String = "++="
- def eval(input: Row): Any = {
+ override def eval(input: Row): Any = {
val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]]
if(leftEval != null) {
val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]]
@@ -109,16 +110,16 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
case class CountSet(child: Expression) extends UnaryExpression {
type EvaluatedType = Any
- def nullable = child.nullable
+ override def nullable: Boolean = child.nullable
- def dataType = LongType
+ override def dataType: DataType = LongType
- def eval(input: Row): Any = {
+ override def eval(input: Row): Any = {
val childEval = child.eval(input).asInstanceOf[OpenHashSet[Any]]
if (childEval != null) {
childEval.size.toLong
}
}
- override def toString = s"$child.count()"
+ override def toString: String = s"$child.count()"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index f85ee0a9bb6d8..3cdca4e9dd2d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -33,8 +33,8 @@ trait StringRegexExpression {
def escape(v: String): String
def matches(regex: Pattern, str: String): Boolean
- def nullable: Boolean = left.nullable || right.nullable
- def dataType: DataType = BooleanType
+ override def nullable: Boolean = left.nullable || right.nullable
+ override def dataType: DataType = BooleanType
// try cache the pattern for Literal
private lazy val cache: Pattern = right match {
@@ -98,11 +98,11 @@ trait CaseConversionExpression {
case class Like(left: Expression, right: Expression)
extends BinaryExpression with StringRegexExpression {
- def symbol = "LIKE"
+ override def symbol: String = "LIKE"
// replace the _ with .{1} exactly match 1 time of any character
// replace the % with .*, match 0 or more times with any character
- override def escape(v: String) =
+ override def escape(v: String): String =
if (!v.isEmpty) {
"(?s)" + (' ' +: v.init).zip(v).flatMap {
case (prev, '\\') => ""
@@ -129,7 +129,7 @@ case class Like(left: Expression, right: Expression)
case class RLike(left: Expression, right: Expression)
extends BinaryExpression with StringRegexExpression {
- def symbol = "RLIKE"
+ override def symbol: String = "RLIKE"
override def escape(v: String): String = v
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
}
@@ -141,7 +141,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
override def convert(v: String): String = v.toUpperCase()
- override def toString() = s"Upper($child)"
+ override def toString: String = s"Upper($child)"
}
/**
@@ -151,7 +151,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE
override def convert(v: String): String = v.toLowerCase()
- override def toString() = s"Lower($child)"
+ override def toString: String = s"Lower($child)"
}
/** A base trait for functions that compare two strings, returning a boolean. */
@@ -160,7 +160,7 @@ trait StringComparison {
type EvaluatedType = Any
- def nullable: Boolean = left.nullable || right.nullable
+ override def nullable: Boolean = left.nullable || right.nullable
override def dataType: DataType = BooleanType
def compare(l: String, r: String): Boolean
@@ -175,9 +175,9 @@ trait StringComparison {
}
}
- def symbol: String = nodeName
+ override def symbol: String = nodeName
- override def toString() = s"$nodeName($left, $right)"
+ override def toString: String = s"$nodeName($left, $right)"
}
/**
@@ -185,7 +185,7 @@ trait StringComparison {
*/
case class Contains(left: Expression, right: Expression)
extends BinaryExpression with StringComparison {
- override def compare(l: String, r: String) = l.contains(r)
+ override def compare(l: String, r: String): Boolean = l.contains(r)
}
/**
@@ -193,7 +193,7 @@ case class Contains(left: Expression, right: Expression)
*/
case class StartsWith(left: Expression, right: Expression)
extends BinaryExpression with StringComparison {
- def compare(l: String, r: String) = l.startsWith(r)
+ override def compare(l: String, r: String): Boolean = l.startsWith(r)
}
/**
@@ -201,7 +201,7 @@ case class StartsWith(left: Expression, right: Expression)
*/
case class EndsWith(left: Expression, right: Expression)
extends BinaryExpression with StringComparison {
- def compare(l: String, r: String) = l.endsWith(r)
+ override def compare(l: String, r: String): Boolean = l.endsWith(r)
}
/**
@@ -212,17 +212,17 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
type EvaluatedType = Any
- override def foldable = str.foldable && pos.foldable && len.foldable
+ override def foldable: Boolean = str.foldable && pos.foldable && len.foldable
- def nullable: Boolean = str.nullable || pos.nullable || len.nullable
- def dataType: DataType = {
+ override def nullable: Boolean = str.nullable || pos.nullable || len.nullable
+ override def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved")
}
if (str.dataType == BinaryType) str.dataType else StringType
}
- override def children = str :: pos :: len :: Nil
+ override def children: Seq[Expression] = str :: pos :: len :: Nil
@inline
def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int)
@@ -267,7 +267,8 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
}
}
- override def toString = len match {
+ override def toString: String = len match {
+ // TODO: This is broken because max is not an integer value.
case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)"
case _ => s"SUBSTR($str, $pos, $len)"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 1a75fcf3545bd..c23d3b61887c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer
import scala.collection.immutable.HashSet
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.FullOuter
@@ -32,6 +33,9 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan]
object DefaultOptimizer extends Optimizer {
val batches =
+ // SubQueries are only needed for analysis and can be removed before execution.
+ Batch("Remove SubQueries", FixedPoint(100),
+ EliminateSubQueries) ::
Batch("Combine Limits", FixedPoint(100),
CombineLimits) ::
Batch("ConstantFolding", FixedPoint(100),
@@ -137,7 +141,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
/** Applies a projection only when the child is producing unnecessary attributes */
- def pruneJoinChild(c: LogicalPlan) = prunedChild(c, allReferences)
+ def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences)
Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index b4c445b3badf1..9c8c643f7d17a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -91,16 +91,18 @@ object PhysicalOperation extends PredicateHelper {
(None, Nil, other, Map.empty)
}
- def collectAliases(fields: Seq[Expression]) = fields.collect {
+ def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect {
case a @ Alias(child, _) => a.toAttribute.asInstanceOf[Attribute] -> child
}.toMap
- def substitute(aliases: Map[Attribute, Expression])(expr: Expression) = expr.transform {
- case a @ Alias(ref: AttributeReference, name) =>
- aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a)
+ def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = {
+ expr.transform {
+ case a @ Alias(ref: AttributeReference, name) =>
+ aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a)
- case a: AttributeReference =>
- aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a)
+ case a: AttributeReference =>
+ aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a)
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 400a6b2825c10..02f7c26a8ab6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -47,9 +47,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
* Attributes that are referenced by expressions but not provided by this nodes children.
* Subclasses should override this method if they produce attributes internally as it is used by
* assertions designed to prevent the construction of invalid plans.
+ *
+ * Note that virtual columns should be excluded. Currently, we only support the grouping ID
+ * virtual column.
*/
- def missingInput: AttributeSet = (references -- inputSet)
- .filter(_.name != VirtualColumn.groupingIdName)
+ def missingInput: AttributeSet =
+ (references -- inputSet).filter(_.name != VirtualColumn.groupingIdName)
/**
* Runs [[transform]] with `rule` on all expressions present in this query operator.
@@ -68,7 +71,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = {
var changed = false
- @inline def transformExpressionDown(e: Expression) = {
+ @inline def transformExpressionDown(e: Expression): Expression = {
val newE = e.transformDown(rule)
if (newE.fastEquals(e)) {
e
@@ -82,6 +85,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
case e: Expression => transformExpressionDown(e)
case Some(e: Expression) => Some(transformExpressionDown(e))
case m: Map[_,_] => m
+ case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map {
case e: Expression => transformExpressionDown(e)
case other => other
@@ -100,7 +104,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = {
var changed = false
- @inline def transformExpressionUp(e: Expression) = {
+ @inline def transformExpressionUp(e: Expression): Expression = {
val newE = e.transformUp(rule)
if (newE.fastEquals(e)) {
e
@@ -114,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
case e: Expression => transformExpressionUp(e)
case Some(e: Expression) => Some(transformExpressionUp(e))
case m: Map[_,_] => m
+ case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map {
case e: Expression => transformExpressionUp(e)
case other => other
@@ -160,5 +165,5 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
*/
protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else ""
- override def simpleString = statePrefix + super.simpleString
+ override def simpleString: String = statePrefix + super.simpleString
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 8c4f09b58a4f2..b01a61d7bf8d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.Logging
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, Resolver}
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, Resolver}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
@@ -73,12 +73,16 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
* can do better should override this function.
*/
def sameResult(plan: LogicalPlan): Boolean = {
- plan.getClass == this.getClass &&
- plan.children.size == children.size && {
- logDebug(s"[${cleanArgs.mkString(", ")}] == [${plan.cleanArgs.mkString(", ")}]")
- cleanArgs == plan.cleanArgs
+ val cleanLeft = EliminateSubQueries(this)
+ val cleanRight = EliminateSubQueries(plan)
+
+ cleanLeft.getClass == cleanRight.getClass &&
+ cleanLeft.children.size == cleanRight.children.size && {
+ logDebug(
+ s"[${cleanRight.cleanArgs.mkString(", ")}] == [${cleanLeft.cleanArgs.mkString(", ")}]")
+ cleanRight.cleanArgs == cleanLeft.cleanArgs
} &&
- (plan.children, children).zipped.forall(_ sameResult _)
+ (cleanLeft.children, cleanRight.children).zipped.forall(_ sameResult _)
}
/** Args that have cleaned such that differences in expression id should not affect equality */
@@ -208,8 +212,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// More than one match.
case ambiguousReferences =>
+ val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ")
throw new AnalysisException(
- s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
+ s"Reference '$name' is ambiguous, could be: $referenceNames.")
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 384fe53a68362..4d9e41a2b5d85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
- def output = projectList.map(_.toAttribute)
+ override def output: Seq[Attribute] = projectList.map(_.toAttribute)
override lazy val resolved: Boolean = {
val containsAggregatesOrGenerators = projectList.exists ( _.collect {
@@ -66,19 +66,19 @@ case class Generate(
}
}
- override def output =
+ override def output: Seq[Attribute] =
if (join) child.output ++ generatorOutput else generatorOutput
}
case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
- override def output = child.output
+ override def output: Seq[Attribute] = child.output
}
case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
// TODO: These aren't really the same attributes as nullability etc might change.
- override def output = left.output
+ override def output: Seq[Attribute] = left.output
- override lazy val resolved =
+ override lazy val resolved: Boolean =
childrenResolved &&
!left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType }
@@ -94,7 +94,7 @@ case class Join(
joinType: JoinType,
condition: Option[Expression]) extends BinaryNode {
- override def output = {
+ override def output: Seq[Attribute] = {
joinType match {
case LeftSemi =>
left.output
@@ -109,7 +109,7 @@ case class Join(
}
}
- def selfJoinResolved = left.outputSet.intersect(right.outputSet).isEmpty
+ private def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
// Joins are only resolved if they don't introduce ambiguious expression ids.
override lazy val resolved: Boolean = {
@@ -118,7 +118,7 @@ case class Join(
}
case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
- def output = left.output
+ override def output: Seq[Attribute] = left.output
}
case class InsertIntoTable(
@@ -128,10 +128,10 @@ case class InsertIntoTable(
overwrite: Boolean)
extends LogicalPlan {
- override def children = child :: Nil
- override def output = child.output
+ override def children: Seq[LogicalPlan] = child :: Nil
+ override def output: Seq[Attribute] = child.output
- override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
+ override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall {
case (childAttr, tableAttr) =>
DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType)
}
@@ -143,14 +143,14 @@ case class CreateTableAsSelect[T](
child: LogicalPlan,
allowExisting: Boolean,
desc: Option[T] = None) extends UnaryNode {
- override def output = Seq.empty[Attribute]
- override lazy val resolved = databaseName != None && childrenResolved
+ override def output: Seq[Attribute] = Seq.empty[Attribute]
+ override lazy val resolved: Boolean = databaseName != None && childrenResolved
}
case class WriteToFile(
path: String,
child: LogicalPlan) extends UnaryNode {
- override def output = child.output
+ override def output: Seq[Attribute] = child.output
}
/**
@@ -163,7 +163,7 @@ case class Sort(
order: Seq[SortOrder],
global: Boolean,
child: LogicalPlan) extends UnaryNode {
- override def output = child.output
+ override def output: Seq[Attribute] = child.output
}
case class Aggregate(
@@ -172,7 +172,7 @@ case class Aggregate(
child: LogicalPlan)
extends UnaryNode {
- override def output = aggregateExpressions.map(_.toAttribute)
+ override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
}
/**
@@ -199,7 +199,7 @@ trait GroupingAnalytics extends UnaryNode {
def groupByExprs: Seq[Expression]
def aggregations: Seq[NamedExpression]
- override def output = aggregations.map(_.toAttribute)
+ override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
}
/**
@@ -264,7 +264,7 @@ case class Rollup(
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
- override def output = child.output
+ override def output: Seq[Attribute] = child.output
override lazy val statistics: Statistics = {
val limit = limitExpr.eval(null).asInstanceOf[Int]
@@ -274,21 +274,21 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
}
case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
- override def output = child.output.map(_.withQualifiers(alias :: Nil))
+ override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
}
case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
extends UnaryNode {
- override def output = child.output
+ override def output: Seq[Attribute] = child.output
}
case class Distinct(child: LogicalPlan) extends UnaryNode {
- override def output = child.output
+ override def output: Seq[Attribute] = child.output
}
case object NoRelation extends LeafNode {
- override def output = Nil
+ override def output: Seq[Attribute] = Nil
/**
* Computes [[Statistics]] for this plan. The default implementation assumes the output
@@ -301,5 +301,5 @@ case object NoRelation extends LeafNode {
}
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
- override def output = left.output
+ override def output: Seq[Attribute] = left.output
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
index 72b0c5c8e7a26..e737418d9c3bc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder}
/**
* Performs a physical redistribution of the data. Used when the consumer of the query
@@ -26,14 +26,11 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder}
abstract class RedistributeData extends UnaryNode {
self: Product =>
- def output = child.output
+ override def output: Seq[Attribute] = child.output
}
case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan)
- extends RedistributeData {
-}
+ extends RedistributeData
case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan)
- extends RedistributeData {
-}
-
+ extends RedistributeData
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 3c3d7a3119064..288c11f69fe22 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Expression, Row, SortOrder}
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{DataType, IntegerType}
/**
* Specifies how tuples that share common expressions will be distributed when a query is executed
@@ -72,7 +72,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
"a single partition.")
// TODO: This is not really valid...
- def clustering = ordering.map(_.child).toSet
+ def clustering: Set[Expression] = ordering.map(_.child).toSet
}
sealed trait Partitioning {
@@ -113,7 +113,7 @@ case object SinglePartition extends Partitioning {
override def satisfies(required: Distribution): Boolean = true
- override def compatibleWith(other: Partitioning) = other match {
+ override def compatibleWith(other: Partitioning): Boolean = other match {
case SinglePartition => true
case _ => false
}
@@ -124,7 +124,7 @@ case object BroadcastPartitioning extends Partitioning {
override def satisfies(required: Distribution): Boolean = true
- override def compatibleWith(other: Partitioning) = other match {
+ override def compatibleWith(other: Partitioning): Boolean = other match {
case SinglePartition => true
case _ => false
}
@@ -139,9 +139,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
extends Expression
with Partitioning {
- override def children = expressions
- override def nullable = false
- override def dataType = IntegerType
+ override def children: Seq[Expression] = expressions
+ override def nullable: Boolean = false
+ override def dataType: DataType = IntegerType
private[this] lazy val clusteringSet = expressions.toSet
@@ -152,7 +152,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}
- override def compatibleWith(other: Partitioning) = other match {
+ override def compatibleWith(other: Partitioning): Boolean = other match {
case BroadcastPartitioning => true
case h: HashPartitioning if h == this => true
case _ => false
@@ -178,9 +178,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
extends Expression
with Partitioning {
- override def children = ordering
- override def nullable = false
- override def dataType = IntegerType
+ override def children: Seq[SortOrder] = ordering
+ override def nullable: Boolean = false
+ override def dataType: DataType = IntegerType
private[this] lazy val clusteringSet = ordering.map(_.child).toSet
@@ -194,7 +194,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case _ => false
}
- override def compatibleWith(other: Partitioning) = other match {
+ override def compatibleWith(other: Partitioning): Boolean = other match {
case BroadcastPartitioning => true
case r: RangePartitioning if r == this => true
case _ => false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index f84ffe4e176cc..a2df51e598a2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.types.DataType
/** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
private class MutableInt(var i: Int)
@@ -35,12 +36,12 @@ object CurrentOrigin {
override def initialValue: Origin = Origin()
}
- def get = value.get()
- def set(o: Origin) = value.set(o)
+ def get: Origin = value.get()
+ def set(o: Origin): Unit = value.set(o)
- def reset() = value.set(Origin())
+ def reset(): Unit = value.set(Origin())
- def setPosition(line: Int, start: Int) = {
+ def setPosition(line: Int, start: Int): Unit = {
value.set(
value.get.copy(line = Some(line), startPosition = Some(start)))
}
@@ -56,7 +57,7 @@ object CurrentOrigin {
abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
self: BaseType with Product =>
- val origin = CurrentOrigin.get
+ val origin: Origin = CurrentOrigin.get
/** Returns a Seq of the children of this node */
def children: Seq[BaseType]
@@ -220,6 +221,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
Some(arg)
}
case m: Map[_,_] => m
+ case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if children contains arg =>
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
@@ -276,6 +278,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
Some(arg)
}
case m: Map[_,_] => m
+ case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if children contains arg =>
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
@@ -307,10 +310,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
* @param newArgs the new product arguments.
*/
def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
+ val defaultCtor =
+ getClass.getConstructors
+ .find(_.getParameterTypes.size != 0)
+ .headOption
+ .getOrElse(sys.error(s"No valid constructor for $nodeName"))
+
try {
CurrentOrigin.withOrigin(origin) {
// Skip no-arg constructors that are just there for kryo.
- val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head
if (otherCopyArgs.isEmpty) {
defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
} else {
@@ -320,18 +328,24 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
} catch {
case e: java.lang.IllegalArgumentException =>
throw new TreeNodeException(
- this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? "
- + s"Exception message: ${e.getMessage}.")
+ this,
+ s"""
+ |Failed to copy node.
+ |Is otherCopyArgs specified correctly for $nodeName.
+ |Exception message: ${e.getMessage}
+ |ctor: $defaultCtor?
+ |args: ${newArgs.mkString(", ")}
+ """.stripMargin)
}
}
/** Returns the name of this type of TreeNode. Defaults to the class name. */
- def nodeName = getClass.getSimpleName
+ def nodeName: String = getClass.getSimpleName
/**
* The arguments that should be included in the arg string. Defaults to the `productIterator`.
*/
- protected def stringArgs = productIterator
+ protected def stringArgs: Iterator[Any] = productIterator
/** Returns a string representing the arguments to this node, minus any children */
def argString: String = productIterator.flatMap {
@@ -343,18 +357,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
}.mkString(", ")
/** String representation of this node without any children */
- def simpleString = s"$nodeName $argString".trim
+ def simpleString: String = s"$nodeName $argString".trim
override def toString: String = treeString
/** Returns a string representation of the nodes in this tree */
- def treeString = generateTreeString(0, new StringBuilder).toString
+ def treeString: String = generateTreeString(0, new StringBuilder).toString
/**
* Returns a string representation of the nodes in this tree, where each operator is numbered.
* The numbers can be used with [[trees.TreeNode.apply apply]] to easily access specific subtrees.
*/
- def numberedTreeString =
+ def numberedTreeString: String =
treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n")
/**
@@ -406,14 +420,14 @@ trait BinaryNode[BaseType <: TreeNode[BaseType]] {
def left: BaseType
def right: BaseType
- def children = Seq(left, right)
+ def children: Seq[BaseType] = Seq(left, right)
}
/**
* A [[TreeNode]] with no children.
*/
trait LeafNode[BaseType <: TreeNode[BaseType]] {
- def children = Nil
+ def children: Seq[BaseType] = Nil
}
/**
@@ -421,6 +435,5 @@ trait LeafNode[BaseType <: TreeNode[BaseType]] {
*/
trait UnaryNode[BaseType <: TreeNode[BaseType]] {
def child: BaseType
- def children = child :: Nil
+ def children: Seq[BaseType] = child :: Nil
}
-
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
index 79a8e06d4b4d4..ea6aa1850db4c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
@@ -41,11 +41,11 @@ package object trees extends Logging {
* A [[TreeNode]] companion for reference equality for Hash based Collection.
*/
class TreeNodeRef(val obj: TreeNode[_]) {
- override def equals(o: Any) = o match {
+ override def equals(o: Any): Boolean = o match {
case that: TreeNodeRef => that.obj.eq(obj)
case _ => false
}
- override def hashCode = if (obj == null) 0 else obj.hashCode
+ override def hashCode: Int = if (obj == null) 0 else obj.hashCode
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index feed50f9a2a2d..c86214a2aa944 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -23,7 +23,7 @@ import org.apache.spark.util.Utils
package object util {
- def fileToString(file: File, encoding: String = "UTF-8") = {
+ def fileToString(file: File, encoding: String = "UTF-8"): String = {
val inStream = new FileInputStream(file)
val outStream = new ByteArrayOutputStream
try {
@@ -45,7 +45,7 @@ package object util {
def resourceToString(
resource:String,
encoding: String = "UTF-8",
- classLoader: ClassLoader = Utils.getSparkClassLoader) = {
+ classLoader: ClassLoader = Utils.getSparkClassLoader): String = {
val inStream = classLoader.getResourceAsStream(resource)
val outStream = new ByteArrayOutputStream
try {
@@ -93,7 +93,7 @@ package object util {
new String(out.toByteArray)
}
- def stringOrNull(a: AnyRef) = if (a == null) null else a.toString
+ def stringOrNull(a: AnyRef): String = if (a == null) null else a.toString
def benchmark[A](f: => A): A = {
val startTime = System.nanoTime()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index c1dd5aa913ddc..756cd36f05c8c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -32,9 +32,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
val caseInsensitiveCatalog = new SimpleCatalog(false)
val caseSensitiveAnalyzer =
- new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true)
+ new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true) {
+ override val extendedResolutionRules = EliminateSubQueries :: Nil
+ }
val caseInsensitiveAnalyzer =
- new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false)
+ new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) {
+ override val extendedResolutionRules = EliminateSubQueries :: Nil
+ }
val checkAnalysis = new CheckAnalysis
@@ -199,4 +203,22 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
assert(pl(3).dataType == DecimalType.Unlimited)
assert(pl(4).dataType == DoubleType)
}
+
+ test("SPARK-6452 regression test") {
+ // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
+ val plan =
+ Aggregate(
+ Nil,
+ Alias(Sum(AttributeReference("a", StringType)(exprId = ExprId(1))), "b")() :: Nil,
+ LocalRelation(
+ AttributeReference("a", StringType)(exprId = ExprId(2))))
+
+ assert(plan.resolved)
+
+ val message = intercept[AnalysisException] {
+ caseSensitiveAnalyze(plan)
+ }.getMessage
+
+ assert(message.contains("resolved attribute(s) a#1 missing from a#2"))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index ec7d15f5bc4e7..3cd7adf8cab5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import scala.language.implicitConversions
import org.apache.spark.annotation.Experimental
+import org.apache.spark.Logging
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedGetField}
@@ -46,7 +47,7 @@ private[sql] object Column {
* @groupname Ungrouped Support functions for DataFrames.
*/
@Experimental
-class Column(protected[sql] val expr: Expression) {
+class Column(protected[sql] val expr: Expression) extends Logging {
def this(name: String) = this(name match {
case "*" => UnresolvedStar(None)
@@ -109,7 +110,15 @@ class Column(protected[sql] val expr: Expression) {
*
* @group expr_ops
*/
- def === (other: Any): Column = EqualTo(expr, lit(other).expr)
+ def === (other: Any): Column = {
+ val right = lit(other).expr
+ if (this.expr == right) {
+ logWarning(
+ s"Constructing trivially true equals predicate, '${this.expr} = $right'. " +
+ "Perhaps you need to use aliases.")
+ }
+ EqualTo(expr, right)
+ }
/**
* Equality test.
@@ -594,6 +603,19 @@ class Column(protected[sql] val expr: Expression) {
*/
def as(alias: Symbol): Column = Alias(expr, alias.name)()
+ /**
+ * Gives the column an alias with metadata.
+ * {{{
+ * val metadata: Metadata = ...
+ * df.select($"colA".as("colB", metadata))
+ * }}}
+ *
+ * @group expr_ops
+ */
+ def as(alias: String, metadata: Metadata): Column = {
+ Alias(expr, alias)(explicitMetadata = Some(metadata))
+ }
+
/**
* Casts the column to a different data type.
* {{{
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index dc9912b52dcab..e59cf9b9e037b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1210,38 +1210,56 @@ class SQLContext(@transient val sparkContext: SparkContext)
* Returns a Catalyst Schema for the given java bean class.
*/
protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
+ val (dataType, _) = inferDataType(beanClass)
+ dataType.asInstanceOf[StructType].fields.map { f =>
+ AttributeReference(f.name, f.dataType, f.nullable)()
+ }
+ }
+
+ /**
+ * Infers the corresponding SQL data type of a Java class.
+ * @param clazz Java class
+ * @return (SQL data type, nullable)
+ */
+ private def inferDataType(clazz: Class[_]): (DataType, Boolean) = {
// TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
- val beanInfo = Introspector.getBeanInfo(beanClass)
-
- // Note: The ordering of elements may differ from when the schema is inferred in Scala.
- // This is because beanInfo.getPropertyDescriptors gives no guarantees about
- // element ordering.
- val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
- fields.map { property =>
- val (dataType, nullable) = property.getPropertyType match {
- case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
- (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
- case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
- case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
- case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
- case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
- case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
- case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
- case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
- case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
-
- case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
- case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
- case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
- case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
- case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
- case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
- case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
- case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
- case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
- case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
- }
- AttributeReference(property.getName, dataType, nullable)()
+ clazz match {
+ case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+ (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
+
+ case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
+ case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
+ case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
+ case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
+ case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
+ case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
+ case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
+ case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
+
+ case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
+ case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
+ case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
+ case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
+ case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
+ case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
+ case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
+
+ case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
+ case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
+ case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
+
+ case c: Class[_] if c.isArray =>
+ val (dataType, nullable) = inferDataType(c.getComponentType)
+ (ArrayType(dataType, nullable), true)
+
+ case _ =>
+ val beanInfo = Introspector.getBeanInfo(clazz)
+ val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+ val fields = properties.map { property =>
+ val (dataType, nullable) = inferDataType(property.getPropertyType)
+ new StructField(property.getName, dataType, nullable)
+ }
+ (new StructType(fields), true)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 20c9bc3e75542..1f5251a20376f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.util.MutablePair
+import org.apache.spark.util.{CompletionIterator, MutablePair}
import org.apache.spark.util.collection.ExternalSorter
/**
@@ -194,7 +194,9 @@ case class ExternalSort(
val ordering = newOrdering(sortOrder, child.output)
val sorter = new ExternalSorter[Row, Null, Row](ordering = Some(ordering))
sorter.insertAll(iterator.map(r => (r, null)))
- sorter.iterator.map(_._1)
+ val baseIterator = sorter.iterator.map(_._1)
+ // TODO(marmbrus): The complex type signature below thwarts inference for no reason.
+ CompletionIterator[Row, Iterator[Row]](baseIterator, sorter.stop())
}, preservesPartitioning = true)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index d2e807d3a69b6..eb46b46ca5bf4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -21,7 +21,7 @@ import scala.language.existentials
import scala.language.implicitConversions
import org.apache.spark.Logging
-import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
+import org.apache.spark.sql.{AnalysisException, SaveMode, DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
@@ -204,19 +204,25 @@ private[sql] object ResolvedDataSource {
provider: String,
options: Map[String, String]): ResolvedDataSource = {
val clazz: Class[_] = lookupDataSource(provider)
+ def className = clazz.getCanonicalName
val relation = userSpecifiedSchema match {
case Some(schema: StructType) => clazz.newInstance() match {
case dataSource: SchemaRelationProvider =>
dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
case dataSource: org.apache.spark.sql.sources.RelationProvider =>
- sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.")
+ throw new AnalysisException(s"$className does not allow user-specified schemas.")
+ case _ =>
+ throw new AnalysisException(s"$className is not a RelationProvider.")
}
case None => clazz.newInstance() match {
case dataSource: RelationProvider =>
dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
- sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.")
+ throw new AnalysisException(
+ s"A schema needs to be specified when using $className.")
+ case _ =>
+ throw new AnalysisException(s"$className is not a RelationProvider.")
}
}
new ResolvedDataSource(clazz, relation)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 2d586f784ac5a..1ff2d5a190521 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -17,29 +17,39 @@
package test.org.apache.spark.sql;
+import java.io.Serializable;
+import java.util.Arrays;
+
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
+import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.test.TestSQLContext$;
-import static org.apache.spark.sql.functions.*;
+import org.apache.spark.sql.types.*;
+import static org.apache.spark.sql.functions.*;
public class JavaDataFrameSuite {
+ private transient JavaSparkContext jsc;
private transient SQLContext context;
@Before
public void setUp() {
// Trigger static initializer of TestData
TestData$.MODULE$.testData();
+ jsc = new JavaSparkContext(TestSQLContext.sparkContext());
context = TestSQLContext$.MODULE$;
}
@After
public void tearDown() {
+ jsc = null;
context = null;
}
@@ -90,4 +100,33 @@ public void testShow() {
df.show();
df.show(1000);
}
+
+ public static class Bean implements Serializable {
+ private double a = 0.0;
+ private Integer[] b = new Integer[]{0, 1};
+
+ public double getA() {
+ return a;
+ }
+
+ public Integer[] getB() {
+ return b;
+ }
+ }
+
+ @Test
+ public void testCreateDataFrameFromJavaBeans() {
+ Bean bean = new Bean();
+ JavaRDD rdd = jsc.parallelize(Arrays.asList(bean));
+ DataFrame df = context.createDataFrame(rdd, Bean.class);
+ StructType schema = df.schema();
+ Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()),
+ schema.apply("a"));
+ Assert.assertEquals(
+ new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()),
+ schema.apply("b"));
+ Row first = df.select("a", "b").first();
+ Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
+ Assert.assertArrayEquals(bean.getB(), first.getAs(1));
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index a53ae97d6243a..bc8fae100db6a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -17,12 +17,10 @@
package org.apache.spark.sql
-import org.apache.spark.sql.catalyst.expressions.NamedExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Project, NoRelation}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._
-import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType}
+import org.apache.spark.sql.types._
class ColumnExpressionSuite extends QueryTest {
@@ -322,4 +320,15 @@ class ColumnExpressionSuite extends QueryTest {
assert('key.desc == 'key.desc)
assert('key.desc != 'key.asc)
}
+
+ test("alias with metadata") {
+ val metadata = new MetadataBuilder()
+ .putString("originName", "value")
+ .build()
+ val schema = testData
+ .select($"*", col("value").as("abc", metadata))
+ .schema
+ assert(schema("value").metadata === Metadata.empty)
+ assert(schema("abc").metadata === metadata)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index ff441ef26f9c0..c30ed694a62f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -108,6 +108,13 @@ class DataFrameSuite extends QueryTest {
)
}
+ test("self join with aliases") {
+ val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str")
+ checkAnswer(
+ df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(),
+ Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
+ }
+
test("explode") {
val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters")
val df2 =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index dd0948ad824be..e4dee87849fd4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -34,7 +34,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("equi-join is hash-join") {
val x = testData2.as("x")
val y = testData2.as("y")
- val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.analyzed
+ val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan
val planned = planner.HashJoin(join)
assert(planned.size === 1)
}
@@ -109,7 +109,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("multiple-key equi-join is hash-join") {
val x = testData2.as("x")
val y = testData2.as("y")
- val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.analyzed
+ val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan
val planned = planner.HashJoin(join)
assert(planned.size === 1)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index be105c6e83594..d615542ab50a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -50,4 +50,10 @@ class UDFSuite extends QueryTest {
.select($"ret.f1").head().getString(0)
assert(result === "test")
}
+
+ test("udf that is transformed") {
+ udf.register("makeStruct", (x: Int, y: Int) => (x, y))
+ // 1 + 1 is constant folded causing a transformation.
+ assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
index f04437c595bf6..968557c9c4686 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
@@ -19,12 +19,29 @@ package org.apache.spark.sql.hive
import java.io.{OutputStream, PrintStream}
+import scala.util.Try
+
+import org.scalatest.BeforeAndAfter
+
import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.{AnalysisException, QueryTest}
-import scala.util.Try
-class ErrorPositionSuite extends QueryTest {
+class ErrorPositionSuite extends QueryTest with BeforeAndAfter {
+
+ before {
+ Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes")
+ }
+
+ positionTest("ambiguous attribute reference 1",
+ "SELECT a from dupAttributes", "a")
+
+ positionTest("ambiguous attribute reference 2",
+ "SELECT a, b from dupAttributes", "a")
+
+ positionTest("ambiguous attribute reference 3",
+ "SELECT b, a from dupAttributes", "a")
positionTest("unresolved attribute 1",
"SELECT x FROM src", "x")
@@ -127,6 +144,10 @@ class ErrorPositionSuite extends QueryTest {
val error = intercept[AnalysisException] {
quietly(sql(query))
}
+
+ assert(!error.getMessage.contains("Seq("))
+ assert(!error.getMessage.contains("List("))
+
val (line, expectedLineNum) = query.split("\n").zipWithIndex.collect {
case (l, i) if l.contains(token) => (l, i + 1)
}.headOption.getOrElse(sys.error(s"Invalid test. Token $token not in $query"))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index ff2e6ea9ea51d..e5ad0bf552073 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -579,7 +579,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
Row(3) :: Row(4) :: Nil
)
- table("test_parquet_ctas").queryExecution.analyzed match {
+ table("test_parquet_ctas").queryExecution.optimizedPlan match {
case LogicalRelation(p: ParquetRelation2) => // OK
case _ =>
fail(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index d891c4e8903d9..8a31bd03092d1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -292,7 +292,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase {
Seq(Row(1, "str1"))
)
- table("test_parquet_ctas").queryExecution.analyzed match {
+ table("test_parquet_ctas").queryExecution.optimizedPlan match {
case LogicalRelation(p: ParquetRelation2) => // OK
case _ =>
fail(
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index db64e11e16304..f73b463d07779 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -67,12 +67,12 @@ object Checkpoint extends Logging {
val REGEX = (PREFIX + """([\d]+)([\w\.]*)""").r
/** Get the checkpoint file for the given checkpoint time */
- def checkpointFile(checkpointDir: String, checkpointTime: Time) = {
+ def checkpointFile(checkpointDir: String, checkpointTime: Time): Path = {
new Path(checkpointDir, PREFIX + checkpointTime.milliseconds)
}
/** Get the checkpoint backup file for the given checkpoint time */
- def checkpointBackupFile(checkpointDir: String, checkpointTime: Time) = {
+ def checkpointBackupFile(checkpointDir: String, checkpointTime: Time): Path = {
new Path(checkpointDir, PREFIX + checkpointTime.milliseconds + ".bk")
}
@@ -232,6 +232,8 @@ object CheckpointReader extends Logging {
def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] =
{
val checkpointPath = new Path(checkpointDir)
+
+ // TODO(rxin): Why is this a def?!
def fs = checkpointPath.getFileSystem(hadoopConf)
// Try to find the checkpoint files
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index 0e285d6088ec1..175140481e5ae 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -100,11 +100,11 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
}
}
- def getInputStreams() = this.synchronized { inputStreams.toArray }
+ def getInputStreams(): Array[InputDStream[_]] = this.synchronized { inputStreams.toArray }
- def getOutputStreams() = this.synchronized { outputStreams.toArray }
+ def getOutputStreams(): Array[DStream[_]] = this.synchronized { outputStreams.toArray }
- def getReceiverInputStreams() = this.synchronized {
+ def getReceiverInputStreams(): Array[ReceiverInputDStream[_]] = this.synchronized {
inputStreams.filter(_.isInstanceOf[ReceiverInputDStream[_]])
.map(_.asInstanceOf[ReceiverInputDStream[_]])
.toArray
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala
index a0d8fb5ab93ec..3249bb348981f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala
@@ -55,7 +55,6 @@ case class Duration (private val millis: Long) {
def div(that: Duration): Double = this / that
-
def isMultipleOf(that: Duration): Boolean =
(this.millis % that.millis == 0)
@@ -71,7 +70,7 @@ case class Duration (private val millis: Long) {
def milliseconds: Long = millis
- def prettyPrint = Utils.msDurationToString(millis)
+ def prettyPrint: String = Utils.msDurationToString(millis)
}
@@ -80,7 +79,7 @@ case class Duration (private val millis: Long) {
* a given number of milliseconds.
*/
object Milliseconds {
- def apply(milliseconds: Long) = new Duration(milliseconds)
+ def apply(milliseconds: Long): Duration = new Duration(milliseconds)
}
/**
@@ -88,7 +87,7 @@ object Milliseconds {
* a given number of seconds.
*/
object Seconds {
- def apply(seconds: Long) = new Duration(seconds * 1000)
+ def apply(seconds: Long): Duration = new Duration(seconds * 1000)
}
/**
@@ -96,7 +95,7 @@ object Seconds {
* a given number of minutes.
*/
object Minutes {
- def apply(minutes: Long) = new Duration(minutes * 60000)
+ def apply(minutes: Long): Duration = new Duration(minutes * 60000)
}
// Java-friendlier versions of the objects above.
@@ -107,16 +106,16 @@ object Durations {
/**
* @return [[org.apache.spark.streaming.Duration]] representing given number of milliseconds.
*/
- def milliseconds(milliseconds: Long) = Milliseconds(milliseconds)
+ def milliseconds(milliseconds: Long): Duration = Milliseconds(milliseconds)
/**
* @return [[org.apache.spark.streaming.Duration]] representing given number of seconds.
*/
- def seconds(seconds: Long) = Seconds(seconds)
+ def seconds(seconds: Long): Duration = Seconds(seconds)
/**
* @return [[org.apache.spark.streaming.Duration]] representing given number of minutes.
*/
- def minutes(minutes: Long) = Minutes(minutes)
+ def minutes(minutes: Long): Duration = Minutes(minutes)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala
index ad4f3fdd14ad6..3f5be785e1b1a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala
@@ -39,18 +39,18 @@ class Interval(val beginTime: Time, val endTime: Time) {
this.endTime < that.endTime
}
- def <= (that: Interval) = (this < that || this == that)
+ def <= (that: Interval): Boolean = (this < that || this == that)
- def > (that: Interval) = !(this <= that)
+ def > (that: Interval): Boolean = !(this <= that)
- def >= (that: Interval) = !(this < that)
+ def >= (that: Interval): Boolean = !(this < that)
- override def toString = "[" + beginTime + ", " + endTime + "]"
+ override def toString: String = "[" + beginTime + ", " + endTime + "]"
}
private[streaming]
object Interval {
- def currentInterval(duration: Duration): Interval = {
+ def currentInterval(duration: Duration): Interval = {
val time = new Time(System.currentTimeMillis)
val intervalBegin = time.floor(duration)
new Interval(intervalBegin, intervalBegin + duration)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 543224d4b07bc..f57f295874645 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -188,7 +188,7 @@ class StreamingContext private[streaming] (
/**
* Return the associated Spark context
*/
- def sparkContext = sc
+ def sparkContext: SparkContext = sc
/**
* Set each DStreams in this context to remember RDDs it generated in the last given duration.
@@ -596,7 +596,8 @@ object StreamingContext extends Logging {
@deprecated("Replaced by implicit functions in the DStream companion object. This is " +
"kept here only for backward compatibility.", "1.3.0")
def toPairDStreamFunctions[K, V](stream: DStream[(K, V)])
- (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = {
+ (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null)
+ : PairDStreamFunctions[K, V] = {
DStream.toPairDStreamFunctions(stream)(kt, vt, ord)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
index 2eabdd9387913..73030e15c5661 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
@@ -415,8 +415,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
implicit val cmv2: ClassTag[V2] = fakeClassTag
implicit val cmw: ClassTag[W] = fakeClassTag
- def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] =
+ def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] = {
transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd
+ }
dstream.transformWith[(K2, V2), W](other.dstream, scalaTransform(_, _, _))
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index 7053f47ec69a2..4c28654ef6413 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -176,11 +176,11 @@ private[python] abstract class PythonDStream(
val func = new TransformFunction(pfunc)
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
- val asJavaDStream = JavaDStream.fromDStream(this)
+ val asJavaDStream: JavaDStream[Array[Byte]] = JavaDStream.fromDStream(this)
}
/**
@@ -212,7 +212,7 @@ private[python] class PythonTransformed2DStream(
val func = new TransformFunction(pfunc)
- override def dependencies = List(parent, parent2)
+ override def dependencies: List[DStream[_]] = List(parent, parent2)
override def slideDuration: Duration = parent.slideDuration
@@ -223,7 +223,7 @@ private[python] class PythonTransformed2DStream(
func(Some(rdd1), Some(rdd2), validTime)
}
- val asJavaDStream = JavaDStream.fromDStream(this)
+ val asJavaDStream: JavaDStream[Array[Byte]] = JavaDStream.fromDStream(this)
}
/**
@@ -260,12 +260,15 @@ private[python] class PythonReducedWindowedDStream(
extends PythonDStream(parent, preduceFunc) {
super.persist(StorageLevel.MEMORY_ONLY)
- override val mustCheckpoint = true
- val invReduceFunc = new TransformFunction(pinvReduceFunc)
+ override val mustCheckpoint: Boolean = true
+
+ val invReduceFunc: TransformFunction = new TransformFunction(pinvReduceFunc)
def windowDuration: Duration = _windowDuration
+
override def slideDuration: Duration = _slideDuration
+
override def parentRememberDuration: Duration = rememberDuration + windowDuration
override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index b874f561c12eb..795c5aa6d585b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -104,7 +104,7 @@ abstract class DStream[T: ClassTag] (
private[streaming] def parentRememberDuration = rememberDuration
/** Return the StreamingContext associated with this DStream */
- def context = ssc
+ def context: StreamingContext = ssc
/* Set the creation call site */
private[streaming] val creationSite = DStream.getCreationSite()
@@ -619,14 +619,16 @@ abstract class DStream[T: ClassTag] (
* operator, so this DStream will be registered as an output stream and there materialized.
*/
def print(num: Int) {
- def foreachFunc = (rdd: RDD[T], time: Time) => {
- val firstNum = rdd.take(num + 1)
- println ("-------------------------------------------")
- println ("Time: " + time)
- println ("-------------------------------------------")
- firstNum.take(num).foreach(println)
- if (firstNum.size > num) println("...")
- println()
+ def foreachFunc: (RDD[T], Time) => Unit = {
+ (rdd: RDD[T], time: Time) => {
+ val firstNum = rdd.take(num + 1)
+ println("-------------------------------------------")
+ println("Time: " + time)
+ println("-------------------------------------------")
+ firstNum.take(num).foreach(println)
+ if (firstNum.size > num) println("...")
+ println()
+ }
}
new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala
index 0dc72790fbdbd..39fd21342813e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala
@@ -114,7 +114,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
}
}
- override def toString() = {
+ override def toString: String = {
"[\n" + currentCheckpointFiles.size + " checkpoint files \n" +
currentCheckpointFiles.mkString("\n") + "\n]"
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
index 22de8c02e63c8..66d519171fd76 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
@@ -298,7 +298,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]](
private[streaming]
class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) {
- def hadoopFiles = data.asInstanceOf[mutable.HashMap[Time, Array[String]]]
+ private def hadoopFiles = data.asInstanceOf[mutable.HashMap[Time, Array[String]]]
override def update(time: Time) {
hadoopFiles.clear()
@@ -320,7 +320,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]](
}
}
- override def toString() = {
+ override def toString: String = {
"[\n" + hadoopFiles.size + " file sets\n" +
hadoopFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n") + "\n]"
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala
index c81534ae584ea..fcd5216f101af 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala
@@ -27,7 +27,7 @@ class FilteredDStream[T: ClassTag](
filterFunc: T => Boolean
) extends DStream[T](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala
index 658623455498c..9d09a3baf37ca 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala
@@ -28,7 +28,7 @@ class FlatMapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag](
flatMapValueFunc: V => TraversableOnce[U]
) extends DStream[(K, U)](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala
index c7bb2833eabb8..475ea2d2d4f38 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala
@@ -27,7 +27,7 @@ class FlatMappedDStream[T: ClassTag, U: ClassTag](
flatMapFunc: T => Traversable[U]
) extends DStream[U](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
index 1361c30395b57..685a32e1d280d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
@@ -28,7 +28,7 @@ class ForEachDStream[T: ClassTag] (
foreachFunc: (RDD[T], Time) => Unit
) extends DStream[Unit](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala
index a9bb51f054048..dbb295fe54f71 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala
@@ -25,7 +25,7 @@ private[streaming]
class GlommedDStream[T: ClassTag](parent: DStream[T])
extends DStream[Array[T]](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
index aa1993f0580a8..e652702e213ef 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
@@ -61,7 +61,7 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext)
}
}
- override def dependencies = List()
+ override def dependencies: List[DStream[_]] = List()
override def slideDuration: Duration = {
if (ssc == null) throw new Exception("ssc is null")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala
index 3d8ee29df1e82..5994bc1e23f2b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala
@@ -28,7 +28,7 @@ class MapPartitionedDStream[T: ClassTag, U: ClassTag](
preservePartitioning: Boolean
) extends DStream[U](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala
index 7aea1f945d9db..954d2eb4a7b00 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala
@@ -28,7 +28,7 @@ class MapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag](
mapValueFunc: V => U
) extends DStream[(K, U)](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala
index 02704a8d1c2e0..fa14b2e897c3e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala
@@ -27,7 +27,7 @@ class MappedDStream[T: ClassTag, U: ClassTag] (
mapFunc: T => U
) extends DStream[U](parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
index c0a5af0b65cc3..1385ccbf56ee5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
@@ -52,7 +52,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag](
// Reduce each batch of data using reduceByKey which will be further reduced by window
// by ReducedWindowedDStream
- val reducedStream = parent.reduceByKey(reduceFunc, partitioner)
+ private val reducedStream = parent.reduceByKey(reduceFunc, partitioner)
// Persist RDDs to memory by default as these RDDs are going to be reused.
super.persist(StorageLevel.MEMORY_ONLY_SER)
@@ -60,7 +60,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag](
def windowDuration: Duration = _windowDuration
- override def dependencies = List(reducedStream)
+ override def dependencies: List[DStream[_]] = List(reducedStream)
override def slideDuration: Duration = _slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
index 880a89bc36895..7757ccac09a58 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
@@ -33,7 +33,7 @@ class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag](
mapSideCombine: Boolean = true
) extends DStream[(K,C)] (parent.ssc) {
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
index ebb04dd35b9a2..de8718d0a80fe 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
@@ -36,7 +36,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
super.persist(StorageLevel.MEMORY_ONLY_SER)
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = parent.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
index 71b61856e23c0..5d46ca0715ffd 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
@@ -32,7 +32,7 @@ class TransformedDStream[U: ClassTag] (
require(parents.map(_.slideDuration).distinct.size == 1,
"Some of the DStreams have different slide durations")
- override def dependencies = parents.toList
+ override def dependencies: List[DStream[_]] = parents.toList
override def slideDuration: Duration = parents.head.slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
index abbc40befa95b..9405dbaa12329 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
@@ -33,17 +33,17 @@ class UnionDStream[T: ClassTag](parents: Array[DStream[T]])
require(parents.map(_.slideDuration).distinct.size == 1,
"Some of the DStreams have different slide durations")
- override def dependencies = parents.toList
+ override def dependencies: List[DStream[_]] = parents.toList
override def slideDuration: Duration = parents.head.slideDuration
override def compute(validTime: Time): Option[RDD[T]] = {
val rdds = new ArrayBuffer[RDD[T]]()
- parents.map(_.getOrCompute(validTime)).foreach(_ match {
+ parents.map(_.getOrCompute(validTime)).foreach {
case Some(rdd) => rdds += rdd
case None => throw new Exception("Could not generate RDD from a parent for unifying at time "
+ validTime)
- })
+ }
if (rdds.size > 0) {
Some(new UnionRDD(ssc.sc, rdds))
} else {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
index 775b6bfd065c0..899865a906c27 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
@@ -46,7 +46,7 @@ class WindowedDStream[T: ClassTag](
def windowDuration: Duration = _windowDuration
- override def dependencies = List(parent)
+ override def dependencies: List[DStream[_]] = List(parent)
override def slideDuration: Duration = _slideDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
index dd1e96334952f..93caa4ba35c7f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
@@ -117,8 +117,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
override def getPreferredLocations(split: Partition): Seq[String] = {
val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition]
val blockLocations = getBlockIdLocations().get(partition.blockId)
- def segmentLocations = HdfsUtils.getFileSegmentLocations(
- partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig)
- blockLocations.getOrElse(segmentLocations)
+ blockLocations.getOrElse(
+ HdfsUtils.getFileSegmentLocations(
+ partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig))
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala
index a7d63bd4f2dbf..cd309788a7717 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala
@@ -17,6 +17,7 @@
package org.apache.spark.streaming.receiver
+import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.duration._
@@ -25,10 +26,10 @@ import scala.reflect.ClassTag
import akka.actor._
import akka.actor.SupervisorStrategy.{Escalate, Restart}
+
import org.apache.spark.{Logging, SparkEnv}
-import org.apache.spark.storage.StorageLevel
-import java.nio.ByteBuffer
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.storage.StorageLevel
/**
* :: DeveloperApi ::
@@ -149,13 +150,13 @@ private[streaming] class ActorReceiver[T: ClassTag](
class Supervisor extends Actor {
override val supervisorStrategy = receiverSupervisorStrategy
- val worker = context.actorOf(props, name)
+ private val worker = context.actorOf(props, name)
logInfo("Started receiver worker at:" + worker.path)
- val n: AtomicInteger = new AtomicInteger(0)
- val hiccups: AtomicInteger = new AtomicInteger(0)
+ private val n: AtomicInteger = new AtomicInteger(0)
+ private val hiccups: AtomicInteger = new AtomicInteger(0)
- def receive = {
+ override def receive: PartialFunction[Any, Unit] = {
case IteratorData(iterator) =>
logDebug("received iterator")
@@ -189,13 +190,12 @@ private[streaming] class ActorReceiver[T: ClassTag](
}
}
- def onStart() = {
+ def onStart(): Unit = {
supervisor
logInfo("Supervision tree for receivers initialized at:" + supervisor.path)
-
}
- def onStop() = {
+ def onStop(): Unit = {
supervisor ! PoisonPill
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
index ee5e639b26d91..42514d8b47dcf 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
@@ -120,7 +120,7 @@ private[streaming] class BlockGenerator(
* `BlockGeneratorListener.onAddData` callback will be called. All received data items
* will be periodically pushed into BlockManager.
*/
- def addDataWithCallback(data: Any, metadata: Any) = synchronized {
+ def addDataWithCallback(data: Any, metadata: Any): Unit = synchronized {
waitToPush()
currentBuffer += data
listener.onAddData(data, metadata)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
index 5acf8a9a811ee..5b5a3fe648602 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
@@ -245,7 +245,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
* Get the unique identifier the receiver input stream that this
* receiver is associated with.
*/
- def streamId = id
+ def streamId: Int = id
/*
* =================
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
index 1f0244c251eba..4943f29395d12 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
@@ -162,13 +162,13 @@ private[streaming] abstract class ReceiverSupervisor(
}
/** Check if receiver has been marked for stopping */
- def isReceiverStarted() = {
+ def isReceiverStarted(): Boolean = {
logDebug("state = " + receiverState)
receiverState == Started
}
/** Check if receiver has been marked for stopping */
- def isReceiverStopped() = {
+ def isReceiverStopped(): Boolean = {
logDebug("state = " + receiverState)
receiverState == Stopped
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index 7d29ed88cfcb4..8f2f1fef76874 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Await
-import akka.actor.{Actor, Props}
+import akka.actor.{ActorRef, Actor, Props}
import akka.pattern.ask
import com.google.common.base.Throwables
import org.apache.hadoop.conf.Configuration
@@ -83,7 +83,7 @@ private[streaming] class ReceiverSupervisorImpl(
private val actor = env.actorSystem.actorOf(
Props(new Actor {
- override def receive() = {
+ override def receive: PartialFunction[Any, Unit] = {
case StopReceiver =>
logInfo("Received stop signal")
stop("Stopped by driver", None)
@@ -92,7 +92,7 @@ private[streaming] class ReceiverSupervisorImpl(
cleanupOldBlocks(threshTime)
}
- def ref = self
+ def ref: ActorRef = self
}), "Receiver-" + streamId + "-" + System.currentTimeMillis())
/** Unique block ids if one wants to add blocks directly */
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
index 7e0f6b2cdfc08..30cf87f5b7dd1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
@@ -36,5 +36,5 @@ class Job(val time: Time, func: () => _) {
id = "streaming job " + time + "." + number
}
- override def toString = id
+ override def toString: String = id
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 59488dfb0f8c6..4946806d2ee95 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -82,7 +82,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
if (eventActor != null) return // generator has already been started
eventActor = ssc.env.actorSystem.actorOf(Props(new Actor {
- def receive = {
+ override def receive: PartialFunction[Any, Unit] = {
case event: JobGeneratorEvent => processEvent(event)
}
}), "JobGenerator")
@@ -111,8 +111,8 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
val pollTime = 100
// To prevent graceful stop to get stuck permanently
- def hasTimedOut = {
- val timedOut = System.currentTimeMillis() - timeWhenStopStarted > stopTimeout
+ def hasTimedOut: Boolean = {
+ val timedOut = (System.currentTimeMillis() - timeWhenStopStarted) > stopTimeout
if (timedOut) {
logWarning("Timed out while stopping the job generator (timeout = " + stopTimeout + ")")
}
@@ -133,7 +133,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
logInfo("Stopped generation timer")
// Wait for the jobs to complete and checkpoints to be written
- def haveAllBatchesBeenProcessed = {
+ def haveAllBatchesBeenProcessed: Boolean = {
lastProcessedBatch != null && lastProcessedBatch.milliseconds == stopTime
}
logInfo("Waiting for jobs to be processed and checkpoints to be written")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 60bc099b27a4c..d6a93acbe711b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -56,7 +56,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
logDebug("Starting JobScheduler")
eventActor = ssc.env.actorSystem.actorOf(Props(new Actor {
- def receive = {
+ override def receive: PartialFunction[Any, Unit] = {
case event: JobSchedulerEvent => processEvent(event)
}
}), "JobScheduler")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
index 8c15a75b1b0e0..5b134877d0b2d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
@@ -28,8 +28,7 @@ private[streaming]
case class JobSet(
time: Time,
jobs: Seq[Job],
- receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty
- ) {
+ receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty) {
private val incompleteJobs = new HashSet[Job]()
private val submissionTime = System.currentTimeMillis() // when this jobset was submitted
@@ -48,17 +47,17 @@ case class JobSet(
if (hasCompleted) processingEndTime = System.currentTimeMillis()
}
- def hasStarted = processingStartTime > 0
+ def hasStarted: Boolean = processingStartTime > 0
- def hasCompleted = incompleteJobs.isEmpty
+ def hasCompleted: Boolean = incompleteJobs.isEmpty
// Time taken to process all the jobs from the time they started processing
// (i.e. not including the time they wait in the streaming scheduler queue)
- def processingDelay = processingEndTime - processingStartTime
+ def processingDelay: Long = processingEndTime - processingStartTime
// Time taken to process all the jobs from the time they were submitted
// (i.e. including the time they wait in the streaming scheduler queue)
- def totalDelay = {
+ def totalDelay: Long = {
processingEndTime - time.milliseconds
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index b36aeb341d25e..98900473138fe 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -72,7 +72,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
private var actor: ActorRef = null
/** Start the actor and receiver execution thread. */
- def start() = synchronized {
+ def start(): Unit = synchronized {
if (actor != null) {
throw new SparkException("ReceiverTracker already started")
}
@@ -86,7 +86,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
}
/** Stop the receiver execution thread. */
- def stop(graceful: Boolean) = synchronized {
+ def stop(graceful: Boolean): Unit = synchronized {
if (!receiverInputStreams.isEmpty && actor != null) {
// First, stop the receivers
if (!skipReceiverLaunch) receiverExecutor.stop(graceful)
@@ -201,7 +201,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
/** Actor to receive messages from the receivers. */
private class ReceiverTrackerActor extends Actor {
- def receive = {
+ override def receive: PartialFunction[Any, Unit] = {
case RegisterReceiver(streamId, typ, host, receiverActor) =>
registerReceiver(streamId, typ, host, receiverActor, sender)
sender ! true
@@ -244,16 +244,15 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
if (graceful) {
val pollTime = 100
- def done = { receiverInfo.isEmpty && !running }
logInfo("Waiting for receiver job to terminate gracefully")
- while(!done) {
+ while (receiverInfo.nonEmpty || running) {
Thread.sleep(pollTime)
}
logInfo("Waited for receiver job to terminate gracefully")
}
// Check if all the receivers have been deregistered or not
- if (!receiverInfo.isEmpty) {
+ if (receiverInfo.nonEmpty) {
logWarning("Not all of the receivers have deregistered, " + receiverInfo)
} else {
logInfo("All of the receivers have deregistered successfully")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
index 5ee53a5c5f561..e4bd067cacb77 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
@@ -17,9 +17,10 @@
package org.apache.spark.streaming.ui
+import scala.collection.mutable.{Queue, HashMap}
+
import org.apache.spark.streaming.{Time, StreamingContext}
import org.apache.spark.streaming.scheduler._
-import scala.collection.mutable.{Queue, HashMap}
import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
import org.apache.spark.streaming.scheduler.StreamingListenerBatchStarted
import org.apache.spark.streaming.scheduler.BatchInfo
@@ -59,11 +60,13 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext)
}
}
- override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) = synchronized {
- runningBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo
+ override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = {
+ synchronized {
+ runningBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo
+ }
}
- override def onBatchStarted(batchStarted: StreamingListenerBatchStarted) = synchronized {
+ override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = synchronized {
runningBatchInfos(batchStarted.batchInfo.batchTime) = batchStarted.batchInfo
waitingBatchInfos.remove(batchStarted.batchInfo.batchTime)
@@ -72,19 +75,21 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext)
}
}
- override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) = synchronized {
- waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime)
- runningBatchInfos.remove(batchCompleted.batchInfo.batchTime)
- completedaBatchInfos.enqueue(batchCompleted.batchInfo)
- if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue()
- totalCompletedBatches += 1L
-
- batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) =>
- totalProcessedRecords += infos.map(_.numRecords).sum
+ override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
+ synchronized {
+ waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime)
+ runningBatchInfos.remove(batchCompleted.batchInfo.batchTime)
+ completedaBatchInfos.enqueue(batchCompleted.batchInfo)
+ if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue()
+ totalCompletedBatches += 1L
+
+ batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) =>
+ totalProcessedRecords += infos.map(_.numRecords).sum
+ }
}
}
- def numReceivers = synchronized {
+ def numReceivers: Int = synchronized {
ssc.graph.getReceiverInputStreams().size
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala
index a73d6f3bf0661..4d968f8bfa7a8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala
@@ -18,9 +18,7 @@
package org.apache.spark.streaming.util
import org.apache.spark.SparkContext
-import org.apache.spark.SparkContext._
import org.apache.spark.util.collection.OpenHashMap
-import scala.collection.JavaConversions.mapAsScalaMap
private[streaming]
object RawTextHelper {
@@ -71,7 +69,7 @@ object RawTextHelper {
var count = 0
while(data.hasNext) {
- value = data.next
+ value = data.next()
if (value != null) {
count += 1
if (len == 0) {
@@ -108,9 +106,13 @@ object RawTextHelper {
}
}
- def add(v1: Long, v2: Long) = (v1 + v2)
+ def add(v1: Long, v2: Long): Long = {
+ v1 + v2
+ }
- def subtract(v1: Long, v2: Long) = (v1 - v2)
+ def subtract(v1: Long, v2: Long): Long = {
+ v1 - v2
+ }
- def max(v1: Long, v2: Long) = math.max(v1, v2)
+ def max(v1: Long, v2: Long): Long = math.max(v1, v2)
}