diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index fc95781448569..ceb7916df0b34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -98,7 +98,7 @@ class SchemaRDD(
   def baseSchemaRDD = this
 
   // =========================================================================================
-  // RDD functions: Copy the interal row representation so we present immutable data to users.
+  // RDD functions: Copy the internal row representation so we present immutable data to users.
   // =========================================================================================
 
   override def compute(split: Partition, context: TaskContext): Iterator[Row] =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 450c142c0baa4..070557e47c4c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -61,7 +61,14 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
       shuffled.map(_._1)
 
     case SinglePartition =>
-      child.execute().coalesce(1, shuffle = true)
+      val rdd = child.execute().mapPartitions { iter =>
+        val mutablePair = new MutablePair[Null, Row]()
+        iter.map(r => mutablePair.update(null, r))
+      }
+      val partitioner = new HashPartitioner(1)
+      val shuffled = new ShuffledRDD[Null, Row, MutablePair[Null, Row]](rdd, partitioner)
+      shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+      shuffled.map(_._2)
 
     case _ => sys.error(s"Exchange not implemented for $newPartitioning")
     // TODO: Handle BroadcastPartitioning.
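The SinglePartition case above replaces coalesce(1, shuffle = true) with an
explicit ShuffledRDD so that the shuffle can be pinned to SparkSqlSerializer
rather than the default serializer. The sketch below shows the same pattern on
a plain RDD; it is a minimal illustration, not part of the patch, and assumes
the Spark 1.0-era ShuffledRDD[K, V, P <: Product2[K, V]] constructor that the
hunk itself uses (toSinglePartition is a hypothetical helper name):

import scala.reflect.ClassTag

import org.apache.spark.HashPartitioner
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.util.MutablePair

object SinglePartitionSketch {
  /** Funnels every record of `input` into one partition via a real shuffle. */
  def toSinglePartition[T: ClassTag](input: RDD[T]): RDD[T] = {
    val keyed = input.mapPartitions { iter =>
      // One MutablePair per partition, reused for every record: the shuffle
      // writer serializes each pair before the next update overwrites it,
      // so no per-record tuple allocation is needed.
      val pair = new MutablePair[Null, T]()
      iter.map(t => pair.update(null, t))
    }
    // Every record hashes to partition 0 of a one-partition HashPartitioner.
    val shuffled =
      new ShuffledRDD[Null, T, MutablePair[Null, T]](keyed, new HashPartitioner(1))
    // The patch additionally calls setSerializer(new SparkSqlSerializer(...))
    // at this point, which is the reason for preferring an explicit
    // ShuffledRDD over coalesce.
    shuffled.map(_._2)
  }
}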
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
index 2fea9702954d7..465e5f146fe71 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -160,12 +160,6 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
     TestTable("src1",
       "CREATE TABLE src1 (key INT, value STRING)".cmd,
       s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd),
-    TestTable("dest1",
-      "CREATE TABLE IF NOT EXISTS dest1 (key INT, value STRING)".cmd),
-    TestTable("dest2",
-      "CREATE TABLE IF NOT EXISTS dest2 (key INT, value STRING)".cmd),
-    TestTable("dest3",
-      "CREATE TABLE IF NOT EXISTS dest3 (key INT, value STRING)".cmd),
     TestTable("srcpart", () => {
       runSqlHive(
         "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)")
@@ -257,6 +251,7 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
 
   private val loadedTables = new collection.mutable.HashSet[String]
 
+  var cacheTables: Boolean = false
   def loadTestTable(name: String) {
     if (!(loadedTables contains name)) {
       // Marks the table as loaded first to prevent infite mutually recursive table loading.
@@ -265,6 +260,9 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
       val createCmds =
         testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name"))
       createCmds.foreach(_())
+
+      if (cacheTables)
+        cacheTable(name)
     }
   }
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 3cc4562a88d66..9370c478bb589 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -218,10 +218,7 @@ abstract class HiveComparisonTest
     val quotes = "\"\"\""
     queryList.zipWithIndex.map {
       case (query, i) =>
-        s"""
-           |val q$i = $quotes$query$quotes.q
-           |q$i.stringResult()
-        """.stripMargin
+        s"""val q$i = hql($quotes$query$quotes); q$i.collect()"""
     }.mkString("\n== Console version of this test ==\n", "\n", "\n")
   }
 
@@ -287,7 +284,6 @@ abstract class HiveComparisonTest
              |Error: ${e.getMessage}
              |${stackTraceToString(e)}
              |$queryString
-             |$consoleTestCase
            """.stripMargin
          stringToFile(
            new File(hiveFailedDirectory, testCaseName),
@@ -313,8 +309,6 @@ abstract class HiveComparisonTest
               |$query
               |== HIVE - ${hive.size} row(s) ==
               |${hive.mkString("\n")}
-              |
-              |$consoleTestCase
            """.stripMargin
          stringToFile(new File(failedDirectory, testCaseName), errorMessage + consoleTestCase)
          fail(errorMessage)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index f76e16bc1afc5..d11b119f40067 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -18,6 +18,17 @@ package org.apache.spark.sql.hive.execution
 
 import org.apache.spark.sql.hive.TestHive
+import org.scalatest.BeforeAndAfter
+
+class HiveInMemoryCompatibilitySuite extends HiveCompatibilitySuite with BeforeAndAfter {
+  override def beforeAll() {
+    TestHive.cacheTables = true
+  }
+
+  override def afterAll() {
+    TestHive.cacheTables = false
+  }
+}
 
 /**
  * Runs the test cases that are included in the hive distribution.
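To make the new behavior concrete, here is a hypothetical driver, not part of
the patch, assuming the TestHive singleton defined in TestHive.scala and the
hql method referenced in HiveComparisonTest above; src1 is one of the test
tables registered in the first TestHive.scala hunk.

import org.apache.spark.sql.hive.TestHive

object CachedTestTableExample extends App {
  TestHive.cacheTables = true     // flag added by this patch; defaults to false
  TestHive.loadTestTable("src1")  // runs the CREATE/LOAD commands, then cacheTable("src1")
  val rows = TestHive.hql("SELECT key, value FROM src1").collect()
  rows.take(10).foreach(println)
}

HiveInMemoryCompatibilitySuite flips the same flag for the entire
compatibility run, so every test table is cached after loading and each query
executes against the cached data instead of re-reading it from disk.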