", actorSystem, master, serializer, 1200, conf,
- securityMgr, mapOutputTracker)
- store.putSingle(rdd(0, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
- store.putSingle(rdd(1, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store = makeBlockManager(12000)
+ store.putSingle(rdd(0, 0), new Array[Byte](4000), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(1, 0), new Array[Byte](4000), StorageLevel.MEMORY_ONLY)
// Access rdd_1_0 to ensure it's not least recently used.
assert(store.getSingle(rdd(1, 0)).isDefined, "rdd_1_0 was not in store")
// According to the same-RDD rule, rdd_1_0 should be replaced here.
- store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY)
// rdd_1_0 should have been replaced, even it's not least recently used.
assert(store.memoryStore.contains(rdd(0, 0)), "rdd_0_0 was not in store")
assert(store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was not in store")
assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store")
}
+
+ test("reserve/release unroll memory") {
+ store = makeBlockManager(12000)
+ val memoryStore = store.memoryStore
+ assert(memoryStore.currentUnrollMemory === 0)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ // Reserve
+ memoryStore.reserveUnrollMemoryForThisThread(100)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 100)
+ memoryStore.reserveUnrollMemoryForThisThread(200)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 300)
+ memoryStore.reserveUnrollMemoryForThisThread(500)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 800)
+ memoryStore.reserveUnrollMemoryForThisThread(1000000)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 800) // not granted
+ // Release
+ memoryStore.releaseUnrollMemoryForThisThread(100)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 700)
+ memoryStore.releaseUnrollMemoryForThisThread(100)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 600)
+ // Reserve again
+ memoryStore.reserveUnrollMemoryForThisThread(4400)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 5000)
+ memoryStore.reserveUnrollMemoryForThisThread(20000)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 5000) // not granted
+ // Release again
+ memoryStore.releaseUnrollMemoryForThisThread(1000)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 4000)
+ memoryStore.releaseUnrollMemoryForThisThread() // release all
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ }
+
+ /**
+ * Verify the result of MemoryStore#unrollSafely is as expected.
+ */
+ private def verifyUnroll(
+ expected: Iterator[Any],
+ result: Either[Array[Any], Iterator[Any]],
+ shouldBeArray: Boolean): Unit = {
+ val actual: Iterator[Any] = result match {
+ case Left(arr: Array[Any]) =>
+ assert(shouldBeArray, "expected iterator from unroll!")
+ arr.iterator
+ case Right(it: Iterator[Any]) =>
+ assert(!shouldBeArray, "expected array from unroll!")
+ it
+ case _ =>
+ fail("unroll returned neither an iterator nor an array...")
+ }
+ expected.zip(actual).foreach { case (e, a) =>
+ assert(e === a, "unroll did not return original values!")
+ }
+ }
+
+ test("safely unroll blocks") {
+ store = makeBlockManager(12000)
+ val smallList = List.fill(40)(new Array[Byte](100))
+ val bigList = List.fill(40)(new Array[Byte](1000))
+ val memoryStore = store.memoryStore
+ val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ // Unroll with all the space in the world. This should succeed and return an array.
+ var unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks)
+ verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ // Unroll with not enough space. This should succeed after kicking out someBlock1.
+ store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY)
+ store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY)
+ unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks)
+ verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ assert(droppedBlocks.size === 1)
+ assert(droppedBlocks.head._1 === TestBlockId("someBlock1"))
+ droppedBlocks.clear()
+
+ // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 =
+ // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator.
+  // In the meantime, however, we kicked out someBlock2 before giving up.
+ store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY)
+ unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator, droppedBlocks)
+ verifyUnroll(bigList.iterator, unrollResult, shouldBeArray = false)
+ assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator
+ assert(droppedBlocks.size === 1)
+ assert(droppedBlocks.head._1 === TestBlockId("someBlock2"))
+ droppedBlocks.clear()
+ }
+
+ test("safely unroll blocks through putIterator") {
+ store = makeBlockManager(12000)
+ val memOnly = StorageLevel.MEMORY_ONLY
+ val memoryStore = store.memoryStore
+ val smallList = List.fill(40)(new Array[Byte](100))
+ val bigList = List.fill(40)(new Array[Byte](1000))
+ def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]]
+ def bigIterator = bigList.iterator.asInstanceOf[Iterator[Any]]
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ // Unroll with plenty of space. This should succeed and cache both blocks.
+ val result1 = memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true)
+ val result2 = memoryStore.putIterator("b2", smallIterator, memOnly, returnValues = true)
+ assert(memoryStore.contains("b1"))
+ assert(memoryStore.contains("b2"))
+ assert(result1.size > 0) // unroll was successful
+ assert(result2.size > 0)
+ assert(result1.data.isLeft) // unroll did not drop this block to disk
+ assert(result2.data.isLeft)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ // Re-put these two blocks so block manager knows about them too. Otherwise, block manager
+ // would not know how to drop them from memory later.
+ memoryStore.remove("b1")
+ memoryStore.remove("b2")
+ store.putIterator("b1", smallIterator, memOnly)
+ store.putIterator("b2", smallIterator, memOnly)
+
+ // Unroll with not enough space. This should succeed but kick out b1 in the process.
+ val result3 = memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true)
+ assert(result3.size > 0)
+ assert(result3.data.isLeft)
+ assert(!memoryStore.contains("b1"))
+ assert(memoryStore.contains("b2"))
+ assert(memoryStore.contains("b3"))
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ memoryStore.remove("b3")
+ store.putIterator("b3", smallIterator, memOnly)
+
+ // Unroll huge block with not enough space. This should fail and kick out b2 in the process.
+ val result4 = memoryStore.putIterator("b4", bigIterator, memOnly, returnValues = true)
+ assert(result4.size === 0) // unroll was unsuccessful
+ assert(result4.data.isLeft)
+ assert(!memoryStore.contains("b1"))
+ assert(!memoryStore.contains("b2"))
+ assert(memoryStore.contains("b3"))
+ assert(!memoryStore.contains("b4"))
+ assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator
+ }
+
+ /**
+ * This test is essentially identical to the preceding one, except that it uses MEMORY_AND_DISK.
+ */
+ test("safely unroll blocks through putIterator (disk)") {
+ store = makeBlockManager(12000)
+ val memAndDisk = StorageLevel.MEMORY_AND_DISK
+ val memoryStore = store.memoryStore
+ val diskStore = store.diskStore
+ val smallList = List.fill(40)(new Array[Byte](100))
+ val bigList = List.fill(40)(new Array[Byte](1000))
+ def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]]
+ def bigIterator = bigList.iterator.asInstanceOf[Iterator[Any]]
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ store.putIterator("b1", smallIterator, memAndDisk)
+ store.putIterator("b2", smallIterator, memAndDisk)
+
+ // Unroll with not enough space. This should succeed but kick out b1 in the process.
+ // Memory store should contain b2 and b3, while disk store should contain only b1
+ val result3 = memoryStore.putIterator("b3", smallIterator, memAndDisk, returnValues = true)
+ assert(result3.size > 0)
+ assert(!memoryStore.contains("b1"))
+ assert(memoryStore.contains("b2"))
+ assert(memoryStore.contains("b3"))
+ assert(diskStore.contains("b1"))
+ assert(!diskStore.contains("b2"))
+ assert(!diskStore.contains("b3"))
+ memoryStore.remove("b3")
+ store.putIterator("b3", smallIterator, StorageLevel.MEMORY_ONLY)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ // Unroll huge block with not enough space. This should fail and drop the new block to disk
+ // directly in addition to kicking out b2 in the process. Memory store should contain only
+ // b3, while disk store should contain b1, b2 and b4.
+ val result4 = memoryStore.putIterator("b4", bigIterator, memAndDisk, returnValues = true)
+ assert(result4.size > 0)
+ assert(result4.data.isRight) // unroll returned bytes from disk
+ assert(!memoryStore.contains("b1"))
+ assert(!memoryStore.contains("b2"))
+ assert(memoryStore.contains("b3"))
+ assert(!memoryStore.contains("b4"))
+ assert(diskStore.contains("b1"))
+ assert(diskStore.contains("b2"))
+ assert(!diskStore.contains("b3"))
+ assert(diskStore.contains("b4"))
+ assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator
+ }
+
+ test("multiple unrolls by the same thread") {
+ store = makeBlockManager(12000)
+ val memOnly = StorageLevel.MEMORY_ONLY
+ val memoryStore = store.memoryStore
+ val smallList = List.fill(40)(new Array[Byte](100))
+ def smallIterator = smallList.iterator.asInstanceOf[Iterator[Any]]
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ // All unroll memory used is released because unrollSafely returned an array
+ memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+ memoryStore.putIterator("b2", smallIterator, memOnly, returnValues = true)
+ assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+
+ // Unroll memory is not released because unrollSafely returned an iterator
+ // that still depends on the underlying vector used in the process
+ memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true)
+ val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisThread
+ assert(unrollMemoryAfterB3 > 0)
+
+ // The unroll memory owned by this thread builds on top of its value after the previous unrolls
+ memoryStore.putIterator("b4", smallIterator, memOnly, returnValues = true)
+ val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisThread
+ assert(unrollMemoryAfterB4 > unrollMemoryAfterB3)
+
+ // ... but only to a certain extent (until we run out of free space to grant new unroll memory)
+ memoryStore.putIterator("b5", smallIterator, memOnly, returnValues = true)
+ val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisThread
+ memoryStore.putIterator("b6", smallIterator, memOnly, returnValues = true)
+ val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisThread
+ memoryStore.putIterator("b7", smallIterator, memOnly, returnValues = true)
+ val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisThread
+ assert(unrollMemoryAfterB5 === unrollMemoryAfterB4)
+ assert(unrollMemoryAfterB6 === unrollMemoryAfterB4)
+ assert(unrollMemoryAfterB7 === unrollMemoryAfterB4)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala
deleted file mode 100644
index 93f0c6a8e6408..0000000000000
--- a/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util
-
-import scala.util.Random
-
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
-
-import org.apache.spark.util.SizeTrackingAppendOnlyMapSuite.LargeDummyClass
-import org.apache.spark.util.collection.{AppendOnlyMap, SizeTrackingAppendOnlyMap}
-
-class SizeTrackingAppendOnlyMapSuite extends FunSuite with BeforeAndAfterAll {
- val NORMAL_ERROR = 0.20
- val HIGH_ERROR = 0.30
-
- test("fixed size insertions") {
- testWith[Int, Long](10000, i => (i, i.toLong))
- testWith[Int, (Long, Long)](10000, i => (i, (i.toLong, i.toLong)))
- testWith[Int, LargeDummyClass](10000, i => (i, new LargeDummyClass()))
- }
-
- test("variable size insertions") {
- val rand = new Random(123456789)
- def randString(minLen: Int, maxLen: Int): String = {
- "a" * (rand.nextInt(maxLen - minLen) + minLen)
- }
- testWith[Int, String](10000, i => (i, randString(0, 10)))
- testWith[Int, String](10000, i => (i, randString(0, 100)))
- testWith[Int, String](10000, i => (i, randString(90, 100)))
- }
-
- test("updates") {
- val rand = new Random(123456789)
- def randString(minLen: Int, maxLen: Int): String = {
- "a" * (rand.nextInt(maxLen - minLen) + minLen)
- }
- testWith[String, Int](10000, i => (randString(0, 10000), i))
- }
-
- def testWith[K, V](numElements: Int, makeElement: (Int) => (K, V)) {
- val map = new SizeTrackingAppendOnlyMap[K, V]()
- for (i <- 0 until numElements) {
- val (k, v) = makeElement(i)
- map(k) = v
- expectWithinError(map, map.estimateSize(), if (i < 32) HIGH_ERROR else NORMAL_ERROR)
- }
- }
-
- def expectWithinError(obj: AnyRef, estimatedSize: Long, error: Double) {
- val betterEstimatedSize = SizeEstimator.estimate(obj)
- assert(betterEstimatedSize * (1 - error) < estimatedSize,
- s"Estimated size $estimatedSize was less than expected size $betterEstimatedSize")
- assert(betterEstimatedSize * (1 + 2 * error) > estimatedSize,
- s"Estimated size $estimatedSize was greater than expected size $betterEstimatedSize")
- }
-}
-
-object SizeTrackingAppendOnlyMapSuite {
- // Speed test, for reproducibility of results.
- // These could be highly non-deterministic in general, however.
- // Results:
- // AppendOnlyMap: 31 ms
- // SizeTracker: 54 ms
- // SizeEstimator: 1500 ms
- def main(args: Array[String]) {
- val numElements = 100000
-
- val baseTimes = for (i <- 0 until 10) yield time {
- val map = new AppendOnlyMap[Int, LargeDummyClass]()
- for (i <- 0 until numElements) {
- map(i) = new LargeDummyClass()
- }
- }
-
- val sampledTimes = for (i <- 0 until 10) yield time {
- val map = new SizeTrackingAppendOnlyMap[Int, LargeDummyClass]()
- for (i <- 0 until numElements) {
- map(i) = new LargeDummyClass()
- map.estimateSize()
- }
- }
-
- val unsampledTimes = for (i <- 0 until 3) yield time {
- val map = new AppendOnlyMap[Int, LargeDummyClass]()
- for (i <- 0 until numElements) {
- map(i) = new LargeDummyClass()
- SizeEstimator.estimate(map)
- }
- }
-
- println("Base: " + baseTimes)
- println("SizeTracker (sampled): " + sampledTimes)
- println("SizeEstimator (unsampled): " + unsampledTimes)
- }
-
- def time(f: => Unit): Long = {
- val start = System.currentTimeMillis()
- f
- System.currentTimeMillis() - start
- }
-
- private class LargeDummyClass {
- val arr = new Array[Int](100)
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 428822949c085..0b7ad184a46d2 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -63,12 +63,13 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
mergeValue, mergeCombiners)
- map.insert(1, 10)
- map.insert(2, 20)
- map.insert(3, 30)
- map.insert(1, 100)
- map.insert(2, 200)
- map.insert(1, 1000)
+ map.insertAll(Seq(
+ (1, 10),
+ (2, 20),
+ (3, 30),
+ (1, 100),
+ (2, 200),
+ (1, 1000)))
val it = map.iterator
assert(it.hasNext)
val result = it.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet))
@@ -282,7 +283,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
assert(w1.hashCode === w2.hashCode)
}
- (1 to 100000).map(_.toString).foreach { i => map.insert(i, i) }
+ map.insertAll((1 to 100000).iterator.map(_.toString).map(i => (i, i)))
collisionPairs.foreach { case (w1, w2) =>
map.insert(w1, w2)
map.insert(w2, w1)
@@ -355,7 +356,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](
createCombiner, mergeValue, mergeCombiners)
- (1 to 100000).foreach { i => map.insert(i, i) }
+ map.insertAll((1 to 100000).iterator.map(i => (i, i)))
map.insert(null.asInstanceOf[Int], 1)
map.insert(1, null.asInstanceOf[Int])
map.insert(null.asInstanceOf[Int], null.asInstanceOf[Int])
diff --git a/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala
new file mode 100644
index 0000000000000..1f33967249654
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/SizeTrackerSuite.scala
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.util.SizeEstimator
+
+class SizeTrackerSuite extends FunSuite {
+ val NORMAL_ERROR = 0.20
+ val HIGH_ERROR = 0.30
+
+ import SizeTrackerSuite._
+
+ test("vector fixed size insertions") {
+ testVector[Long](10000, i => i.toLong)
+ testVector[(Long, Long)](10000, i => (i.toLong, i.toLong))
+ testVector[LargeDummyClass](10000, i => new LargeDummyClass)
+ }
+
+ test("vector variable size insertions") {
+ val rand = new Random(123456789)
+ def randString(minLen: Int, maxLen: Int): String = {
+ "a" * (rand.nextInt(maxLen - minLen) + minLen)
+ }
+ testVector[String](10000, i => randString(0, 10))
+ testVector[String](10000, i => randString(0, 100))
+ testVector[String](10000, i => randString(90, 100))
+ }
+
+ test("map fixed size insertions") {
+ testMap[Int, Long](10000, i => (i, i.toLong))
+ testMap[Int, (Long, Long)](10000, i => (i, (i.toLong, i.toLong)))
+ testMap[Int, LargeDummyClass](10000, i => (i, new LargeDummyClass))
+ }
+
+ test("map variable size insertions") {
+ val rand = new Random(123456789)
+ def randString(minLen: Int, maxLen: Int): String = {
+ "a" * (rand.nextInt(maxLen - minLen) + minLen)
+ }
+ testMap[Int, String](10000, i => (i, randString(0, 10)))
+ testMap[Int, String](10000, i => (i, randString(0, 100)))
+ testMap[Int, String](10000, i => (i, randString(90, 100)))
+ }
+
+ test("map updates") {
+ val rand = new Random(123456789)
+ def randString(minLen: Int, maxLen: Int): String = {
+ "a" * (rand.nextInt(maxLen - minLen) + minLen)
+ }
+ testMap[String, Int](10000, i => (randString(0, 10000), i))
+ }
+
+ def testVector[T: ClassTag](numElements: Int, makeElement: Int => T) {
+ val vector = new SizeTrackingVector[T]
+ for (i <- 0 until numElements) {
+ val item = makeElement(i)
+ vector += item
+ expectWithinError(vector, vector.estimateSize(), if (i < 32) HIGH_ERROR else NORMAL_ERROR)
+ }
+ }
+
+ def testMap[K, V](numElements: Int, makeElement: (Int) => (K, V)) {
+ val map = new SizeTrackingAppendOnlyMap[K, V]
+ for (i <- 0 until numElements) {
+ val (k, v) = makeElement(i)
+ map(k) = v
+ expectWithinError(map, map.estimateSize(), if (i < 32) HIGH_ERROR else NORMAL_ERROR)
+ }
+ }
+
+ def expectWithinError(obj: AnyRef, estimatedSize: Long, error: Double) {
+ val betterEstimatedSize = SizeEstimator.estimate(obj)
+ assert(betterEstimatedSize * (1 - error) < estimatedSize,
+ s"Estimated size $estimatedSize was less than expected size $betterEstimatedSize")
+ assert(betterEstimatedSize * (1 + 2 * error) > estimatedSize,
+ s"Estimated size $estimatedSize was greater than expected size $betterEstimatedSize")
+ }
+}
+
+private object SizeTrackerSuite {
+
+ /**
+ * Run speed tests for size tracking collections.
+ */
+ def main(args: Array[String]): Unit = {
+ if (args.size < 1) {
+ println("Usage: SizeTrackerSuite [num elements]")
+ System.exit(1)
+ }
+ val numElements = args(0).toInt
+ vectorSpeedTest(numElements)
+ mapSpeedTest(numElements)
+ }
+
+ /**
+ * Speed test for SizeTrackingVector.
+ *
+ * Results for 100000 elements (possibly non-deterministic):
+ * PrimitiveVector 15 ms
+ * SizeTracker 51 ms
+ * SizeEstimator 2000 ms
+ */
+ def vectorSpeedTest(numElements: Int): Unit = {
+ val baseTimes = for (i <- 0 until 10) yield time {
+ val vector = new PrimitiveVector[LargeDummyClass]
+ for (i <- 0 until numElements) {
+ vector += new LargeDummyClass
+ }
+ }
+ val sampledTimes = for (i <- 0 until 10) yield time {
+ val vector = new SizeTrackingVector[LargeDummyClass]
+ for (i <- 0 until numElements) {
+ vector += new LargeDummyClass
+ vector.estimateSize()
+ }
+ }
+ val unsampledTimes = for (i <- 0 until 3) yield time {
+ val vector = new PrimitiveVector[LargeDummyClass]
+ for (i <- 0 until numElements) {
+ vector += new LargeDummyClass
+ SizeEstimator.estimate(vector)
+ }
+ }
+ printSpeedTestResult("SizeTrackingVector", baseTimes, sampledTimes, unsampledTimes)
+ }
+
+ /**
+ * Speed test for SizeTrackingAppendOnlyMap.
+ *
+ * Results for 100000 elements (possibly non-deterministic):
+ * AppendOnlyMap 30 ms
+ * SizeTracker 41 ms
+ * SizeEstimator 1666 ms
+ */
+ def mapSpeedTest(numElements: Int): Unit = {
+ val baseTimes = for (i <- 0 until 10) yield time {
+ val map = new AppendOnlyMap[Int, LargeDummyClass]
+ for (i <- 0 until numElements) {
+ map(i) = new LargeDummyClass
+ }
+ }
+ val sampledTimes = for (i <- 0 until 10) yield time {
+ val map = new SizeTrackingAppendOnlyMap[Int, LargeDummyClass]
+ for (i <- 0 until numElements) {
+ map(i) = new LargeDummyClass
+ map.estimateSize()
+ }
+ }
+ val unsampledTimes = for (i <- 0 until 3) yield time {
+ val map = new AppendOnlyMap[Int, LargeDummyClass]
+ for (i <- 0 until numElements) {
+ map(i) = new LargeDummyClass
+ SizeEstimator.estimate(map)
+ }
+ }
+ printSpeedTestResult("SizeTrackingAppendOnlyMap", baseTimes, sampledTimes, unsampledTimes)
+ }
+
+ def printSpeedTestResult(
+ testName: String,
+ baseTimes: Seq[Long],
+ sampledTimes: Seq[Long],
+ unsampledTimes: Seq[Long]): Unit = {
+ println(s"Average times for $testName (ms):")
+ println(" Base - " + averageTime(baseTimes))
+ println(" SizeTracker (sampled) - " + averageTime(sampledTimes))
+ println(" SizeEstimator (unsampled) - " + averageTime(unsampledTimes))
+ println()
+ }
+
+ def time(f: => Unit): Long = {
+ val start = System.currentTimeMillis()
+ f
+ System.currentTimeMillis() - start
+ }
+
+ def averageTime(v: Seq[Long]): Long = {
+ v.sum / v.size
+ }
+
+ private class LargeDummyClass {
+ val arr = new Array[Int](100)
+ }
+}
diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh
index 33de24d1ae6d7..38830103d1e8d 100755
--- a/dev/create-release/create-release.sh
+++ b/dev/create-release/create-release.sh
@@ -53,7 +53,7 @@ if [[ ! "$@" =~ --package-only ]]; then
-Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \
-Dmaven.javadoc.skip=true \
-Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
- -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\
+ -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl\
-Dtag=$GIT_TAG -DautoVersionSubmodules=true \
--batch-mode release:prepare
@@ -61,7 +61,7 @@ if [[ ! "$@" =~ --package-only ]]; then
-Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \
-Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
-Dmaven.javadoc.skip=true \
- -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\
+ -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl\
release:perform
cd ..
@@ -111,10 +111,10 @@ make_binary_release() {
spark-$RELEASE_VERSION-bin-$NAME.tgz.sha
}
-make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4"
-make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0"
+make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4"
+make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0"
make_binary_release "hadoop2" \
- "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0"
+ "-Phive -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0"
# Copy data
echo "Copying release tarballs"
diff --git a/dev/run-tests b/dev/run-tests
index 98ec969dc1b37..51e4def0f835a 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -65,7 +65,7 @@ echo "========================================================================="
# (either resolution or compilation) prompts the user for input either q, r,
# etc to quit or retry. This echo is there to make it not block.
if [ -n "$_RUN_SQL_TESTS" ]; then
- echo -e "q\n" | SBT_MAVEN_PROFILES="$SBT_MAVEN_PROFILES -Phive -Phive-thriftserver" sbt/sbt clean package \
+ echo -e "q\n" | SBT_MAVEN_PROFILES="$SBT_MAVEN_PROFILES -Phive" sbt/sbt clean package \
assembly/assembly test | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
else
echo -e "q\n" | sbt/sbt clean package assembly/assembly test | \
diff --git a/dev/scalastyle b/dev/scalastyle
index d9f2b91a3a091..a02d06912f238 100755
--- a/dev/scalastyle
+++ b/dev/scalastyle
@@ -17,7 +17,7 @@
# limitations under the License.
#
-echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt
+echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt
# Check style with YARN alpha built too
echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \
>> scalastyle.txt
diff --git a/docs/configuration.md b/docs/configuration.md
index dac8bb1d52468..2e6c85cc2bcca 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -239,7 +239,7 @@ Apart from these, the following properties are also available, and may be useful
spark.shuffle.memoryFraction |
- 0.3 |
+ 0.2 |
Fraction of Java heap to use for aggregation and cogroups during shuffles, if
spark.shuffle.spill is true. At any given time, the collective size of
@@ -380,13 +380,13 @@ Apart from these, the following properties are also available, and may be useful
|
spark.serializer.objectStreamReset |
- 10000 |
+ 100 |
When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches
objects to prevent writing redundant data, however that stops garbage collection of those
objects. By calling 'reset' you flush that info from the serializer, and allow old
objects to be collected. To turn off this periodic reset set it to a value <= 0.
- By default it will reset the serializer every 10,000 objects.
+ By default it will reset the serializer every 100 objects.
|
@@ -480,6 +480,15 @@ Apart from these, the following properties are also available, and may be useful
increase it if you configure your own old generation size.
+
+ spark.storage.unrollFraction |
+ 0.2 |
+
+ Fraction of spark.storage.memoryFraction to use for unrolling blocks in memory.
+ This is dynamically allocated by dropping existing blocks when there is not enough free
+ storage space to unroll the new block in its entirety.
+ |
+
spark.tachyonStore.baseDir |
System.getProperty("java.io.tmpdir") |
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 36d642f2923b2..38728534a46e0 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -136,7 +136,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.createSchemaRDD
// Define the schema using a case class.
-// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit,
+// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit,
// you can use custom classes that implement the Product interface.
case class Person(name: String, age: Int)
@@ -548,6 +548,7 @@ results = hiveContext.hql("FROM src SELECT key, value").collect()
+
# Writing Language-Integrated Relational Queries
**Language-Integrated queries are currently only supported in Scala.**
@@ -572,199 +573,4 @@ prefixed with a tick (`'`). Implicit conversions turn these symbols into expres
evaluated by the SQL execution engine. A full list of the functions supported can be found in the
[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD).
-
-
-## Running the Thrift JDBC server
-
-The Thrift JDBC server implemented here corresponds to the [`HiveServer2`]
-(https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) in Hive 0.12. You can test
-the JDBC server with the beeline script comes with either Spark or Hive 0.12. In order to use Hive
-you must first run '`sbt/sbt -Phive-thriftserver assembly/assembly`' (or use `-Phive-thriftserver`
-for maven).
-
-To start the JDBC server, run the following in the Spark directory:
-
- ./sbin/start-thriftserver.sh
-
-The default port the server listens on is 10000. You may run
-`./sbin/start-thriftserver.sh --help` for a complete list of all available
-options. Now you can use beeline to test the Thrift JDBC server:
-
- ./bin/beeline
-
-Connect to the JDBC server in beeline with:
-
- beeline> !connect jdbc:hive2://localhost:10000
-
-Beeline will ask you for a username and password. In non-secure mode, simply enter the username on
-your machine and a blank password. For secure mode, please follow the instructions given in the
-[beeline documentation](https://cwiki.apache.org/confluence/display/Hive/HiveServer2+Clients)
-
-Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`.
-
-You may also use the beeline script comes with Hive.
-
-### Migration Guide for Shark Users
-
-#### Reducer number
-
-In Shark, default reducer number is 1 and is controlled by the property `mapred.reduce.tasks`. Spark
-SQL deprecates this property by a new property `spark.sql.shuffle.partitions`, whose default value
-is 200. Users may customize this property via `SET`:
-
-```
-SET spark.sql.shuffle.partitions=10;
-SELECT page, count(*) c FROM logs_last_month_cached
-GROUP BY page ORDER BY c DESC LIMIT 10;
-```
-
-You may also put this property in `hive-site.xml` to override the default value.
-
-For now, the `mapred.reduce.tasks` property is still recognized, and is converted to
-`spark.sql.shuffle.partitions` automatically.
-
-#### Caching
-
-The `shark.cache` table property no longer exists, and tables whose name end with `_cached` are no
-longer automcatically cached. Instead, we provide `CACHE TABLE` and `UNCACHE TABLE` statements to
-let user control table caching explicitly:
-
-```
-CACHE TABLE logs_last_month;
-UNCACHE TABLE logs_last_month;
-```
-
-**NOTE** `CACHE TABLE tbl` is lazy, it only marks table `tbl` as "need to by cached if necessary",
-but doesn't actually cache it until a query that touches `tbl` is executed. To force the table to be
-cached, you may simply count the table immediately after executing `CACHE TABLE`:
-
-```
-CACHE TABLE logs_last_month;
-SELECT COUNT(1) FROM logs_last_month;
-```
-
-Several caching related features are not supported yet:
-
-* User defined partition level cache eviction policy
-* RDD reloading
-* In-memory cache write through policy
-
-### Compatibility with Apache Hive
-
-#### Deploying in Exising Hive Warehouses
-
-Spark SQL Thrift JDBC server is designed to be "out of the box" compatible with existing Hive
-installations. You do not need to modify your existing Hive Metastore or change the data placement
-or partitioning of your tables.
-
-#### Supported Hive Features
-
-Spark SQL supports the vast majority of Hive features, such as:
-
-* Hive query statements, including:
- * `SELECT`
- * `GROUP BY
- * `ORDER BY`
- * `CLUSTER BY`
- * `SORT BY`
-* All Hive operators, including:
- * Relational operators (`=`, `⇔`, `==`, `<>`, `<`, `>`, `>=`, `<=`, etc)
- * Arthimatic operators (`+`, `-`, `*`, `/`, `%`, etc)
- * Logical operators (`AND`, `&&`, `OR`, `||`, etc)
- * Complex type constructors
- * Mathemtatical functions (`sign`, `ln`, `cos`, etc)
- * String functions (`instr`, `length`, `printf`, etc)
-* User defined functions (UDF)
-* User defined aggregation functions (UDAF)
-* User defined serialization formats (SerDe's)
-* Joins
- * `JOIN`
- * `{LEFT|RIGHT|FULL} OUTER JOIN`
- * `LEFT SEMI JOIN`
- * `CROSS JOIN`
-* Unions
-* Sub queries
- * `SELECT col FROM ( SELECT a + b AS col from t1) t2`
-* Sampling
-* Explain
-* Partitioned tables
-* All Hive DDL Functions, including:
- * `CREATE TABLE`
- * `CREATE TABLE AS SELECT`
- * `ALTER TABLE`
-* Most Hive Data types, including:
- * `TINYINT`
- * `SMALLINT`
- * `INT`
- * `BIGINT`
- * `BOOLEAN`
- * `FLOAT`
- * `DOUBLE`
- * `STRING`
- * `BINARY`
- * `TIMESTAMP`
- * `ARRAY<>`
- * `MAP<>`
- * `STRUCT<>`
-
-#### Unsupported Hive Functionality
-
-Below is a list of Hive features that we don't support yet. Most of these features are rarely used
-in Hive deployments.
-
-**Major Hive Features**
-
-* Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL
- doesn't support buckets yet.
-
-**Esoteric Hive Features**
-
-* Tables with partitions using different input formats: In Spark SQL, all table partitions need to
- have the same input format.
-* Non-equi outer join: For the uncommon use case of using outer joins with non-equi join conditions
- (e.g. condition "`key < 10`"), Spark SQL will output wrong result for the `NULL` tuple.
-* `UNIONTYPE`
-* Unique join
-* Single query multi insert
-* Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at
- the moment.
-
-**Hive Input/Output Formats**
-
-* File format for CLI: For results showing back to the CLI, Spark SQL only supports TextOutputFormat.
-* Hadoop archive
-
-**Hive Optimizations**
-
-A handful of Hive optimizations are not yet included in Spark. Some of these (such as indexes) are
-not necessary due to Spark SQL's in-memory computational model. Others are slotted for future
-releases of Spark SQL.
-
-* Block level bitmap indexes and virtual columns (used to build indexes)
-* Automatically convert a join to map join: For joining a large table with multiple small tables,
- Hive automatically converts the join into a map join. We are adding this auto conversion in the
- next release.
-* Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you
- need to control the degree of parallelism post-shuffle using "SET
- spark.sql.shuffle.partitions=[num_tasks];". We are going to add auto-setting of parallelism in the
- next release.
-* Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still
- launches tasks to compute the result.
-* Skew data flag: Spark SQL does not follow the skew data flags in Hive.
-* `STREAMTABLE` hint in join: Spark SQL does not follow the `STREAMTABLE` hint.
-* Merge multiple small files for query results: if the result output contains multiple small files,
- Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS
- metadata. Spark SQL does not support that.
-
-## Running the Spark SQL CLI
-
-The Spark SQL CLI is a convenient tool to run the Hive metastore service in local mode and execute
-queries input from command line. Note: the Spark SQL CLI cannot talk to the Thrift JDBC server.
-
-To start the Spark SQL CLI, run the following in the Spark directory:
-
- ./bin/spark-sql
-
-Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`.
-You may run `./bin/spark-sql --help` for a complete list of all available
-options.
+
\ No newline at end of file
diff --git a/examples/pom.xml b/examples/pom.xml
index c4ed0f5a6a02b..bd1c387c2eb91 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-examples_2.10
- examples
+ examples
jar
Spark Project Examples
diff --git a/external/flume/pom.xml b/external/flume/pom.xml
index 9f680b27c3308..e6b3cc36702c8 100644
--- a/external/flume/pom.xml
+++ b/external/flume/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming-flume_2.10
- streaming-flume
+ streaming-flume
jar
Spark Project External Flume
diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml
index 25a5c0a4d7d77..4762c50685a93 100644
--- a/external/kafka/pom.xml
+++ b/external/kafka/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming-kafka_2.10
- streaming-kafka
+ streaming-kafka
jar
Spark Project External Kafka
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index f31ed655f6779..32c530e600ce0 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming-mqtt_2.10
- streaming-mqtt
+ streaming-mqtt
jar
Spark Project External MQTT
diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml
index 56bb24c2a072e..637adb0f00da0 100644
--- a/external/twitter/pom.xml
+++ b/external/twitter/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming-twitter_2.10
- streaming-twitter
+ streaming-twitter
jar
Spark Project External Twitter
diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml
index 54b0242c54e78..e4d758a04a4cd 100644
--- a/external/zeromq/pom.xml
+++ b/external/zeromq/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming-zeromq_2.10
- streaming-zeromq
+ streaming-zeromq
jar
Spark Project External ZeroMQ
diff --git a/graphx/pom.xml b/graphx/pom.xml
index 6dd52fc618b1e..7e3bcf29dcfbc 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-graphx_2.10
- graphx
+ graphx
jar
Spark Project GraphX
diff --git a/mllib/pom.xml b/mllib/pom.xml
index f27cf520dc9fa..92b07e2357db1 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-mllib_2.10
- mllib
+ mllib
jar
Spark Project ML Library
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index c44173793b39a..954621ee8b933 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -54,6 +54,13 @@ class PythonMLLibAPI extends Serializable {
}
}
+ private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = {
+ require(bytes.length - offset == 8, "Wrong size byte array for Double")
+ val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
+ bb.order(ByteOrder.nativeOrder())
+ bb.getDouble
+ }
+
private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
val packetLength = bytes.length - offset
require(packetLength >= 5, "Byte array too short")
@@ -89,6 +96,22 @@ class PythonMLLibAPI extends Serializable {
Vectors.sparse(size, indices, values)
}
+ /**
+ * Returns an 8-byte array for the input Double.
+ *
+ * Note: we currently do not use a magic byte for double for storage efficiency.
+ * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety.
+ * The corresponding deserializer, deserializeDouble, needs to be modified as well if the
+ * serialization scheme changes.
+ */
+ private[python] def serializeDouble(double: Double): Array[Byte] = {
+ val bytes = new Array[Byte](8)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ bb.putDouble(double)
+ bytes
+ }
+
private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = {
val len = doubles.length
val bytes = new Array[Byte](5 + 8 * len)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index b6e0c4a80e27b..6c7be0a4f1dcb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -54,7 +54,13 @@ class NaiveBayesModel private[mllib] (
}
}
- override def predict(testData: RDD[Vector]): RDD[Double] = testData.map(predict)
+ override def predict(testData: RDD[Vector]): RDD[Double] = {
+ val bcModel = testData.context.broadcast(this)
+ testData.mapPartitions { iter =>
+ val model = bcModel.value
+ iter.map(model.predict)
+ }
+ }
override def predict(testData: Vector): Double = {
labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index de22fbb6ffc10..db425d866bbad 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -165,18 +165,21 @@ class KMeans private (
val activeCenters = activeRuns.map(r => centers(r)).toArray
val costAccums = activeRuns.map(_ => sc.accumulator(0.0))
+ val bcActiveCenters = sc.broadcast(activeCenters)
+
// Find the sum and count of points mapping to each center
val totalContribs = data.mapPartitions { points =>
- val runs = activeCenters.length
- val k = activeCenters(0).length
- val dims = activeCenters(0)(0).vector.length
+ val thisActiveCenters = bcActiveCenters.value
+ val runs = thisActiveCenters.length
+ val k = thisActiveCenters(0).length
+ val dims = thisActiveCenters(0)(0).vector.length
val sums = Array.fill(runs, k)(BDV.zeros[Double](dims).asInstanceOf[BV[Double]])
val counts = Array.fill(runs, k)(0L)
points.foreach { point =>
(0 until runs).foreach { i =>
- val (bestCenter, cost) = KMeans.findClosest(activeCenters(i), point)
+ val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point)
costAccums(i) += cost
sums(i)(bestCenter) += point.vector
counts(i)(bestCenter) += 1
@@ -264,16 +267,17 @@ class KMeans private (
// to their squared distance from that run's current centers
var step = 0
while (step < initializationSteps) {
+ val bcCenters = data.context.broadcast(centers)
val sumCosts = data.flatMap { point =>
(0 until runs).map { r =>
- (r, KMeans.pointCost(centers(r), point))
+ (r, KMeans.pointCost(bcCenters.value(r), point))
}
}.reduceByKey(_ + _).collectAsMap()
val chosen = data.mapPartitionsWithIndex { (index, points) =>
val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
points.flatMap { p =>
(0 until runs).filter { r =>
- rand.nextDouble() < 2.0 * KMeans.pointCost(centers(r), p) * k / sumCosts(r)
+ rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r)
}.map((_, p))
}
}.collect()
@@ -286,9 +290,10 @@ class KMeans private (
// Finally, we might have a set of more than k candidate centers for each run; weigh each
// candidate by the number of points in the dataset mapping to it and run a local k-means++
// on the weighted centers to pick just k of them
+ val bcCenters = data.context.broadcast(centers)
val weightMap = data.flatMap { p =>
(0 until runs).map { r =>
- ((r, KMeans.findClosest(centers(r), p)._1), 1.0)
+ ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
}
}.reduceByKey(_ + _).collectAsMap()
val finalCenters = (0 until runs).map { r =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index fba21aefaaacd..5823cb6e52e7f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -38,7 +38,8 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Ser
/** Maps given points to their cluster indices. */
def predict(points: RDD[Vector]): RDD[Int] = {
val centersWithNorm = clusterCentersWithNorm
- points.map(p => KMeans.findClosest(centersWithNorm, new BreezeVectorWithNorm(p))._1)
+ val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
+ points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new BreezeVectorWithNorm(p))._1)
}
/** Maps given points to their cluster indices. */
@@ -51,7 +52,8 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Ser
*/
def computeCost(data: RDD[Vector]): Double = {
val centersWithNorm = clusterCentersWithNorm
- data.map(p => KMeans.pointCost(centersWithNorm, new BreezeVectorWithNorm(p))).sum()
+ val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
+ data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new BreezeVectorWithNorm(p))).sum()
}
private def clusterCentersWithNorm: Iterable[BreezeVectorWithNorm] =
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 7030eeabe400a..9fd760bf78083 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -163,6 +163,7 @@ object GradientDescent extends Logging {
// Initialize weights as a column vector
var weights = Vectors.dense(initialWeights.toArray)
+ val n = weights.size
/**
* For the first iteration, the regVal will be initialized as sum of weight squares
@@ -172,12 +173,13 @@ object GradientDescent extends Logging {
weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
for (i <- 1 to numIterations) {
+ val bcWeights = data.context.broadcast(weights)
// Sample a subset (fraction miniBatchFraction) of the total data
// compute and sum up the subgradients on this subset (this is one map-reduce)
val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
- .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+ .aggregate((BDV.zeros[Double](n), 0.0))(
seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
- val l = gradient.compute(features, label, weights, Vectors.fromBreeze(grad))
+ val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad))
(grad, loss + l)
},
combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index 7bbed9c8fdbef..179cd4a3f1625 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -195,13 +195,14 @@ object LBFGS extends Logging {
override def calculate(weights: BDV[Double]) = {
// Have a local copy to avoid the serialization of CostFun object which is not serializable.
- val localData = data
val localGradient = gradient
+ val n = weights.length
+ val bcWeights = data.context.broadcast(weights)
- val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))(
+ val (gradientSum, lossSum) = data.aggregate((BDV.zeros[Double](n), 0.0))(
seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
val l = localGradient.compute(
- features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
+ features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad))
(grad, loss + l)
},
combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala
new file mode 100644
index 0000000000000..7ecb409c4a91a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.random
+
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom}
+
+/**
+ * :: Experimental ::
+ * Trait for random number generators that generate i.i.d. values from a distribution.
+ */
+@Experimental
+trait DistributionGenerator extends Pseudorandom with Serializable {
+
+ /**
+ * Returns an i.i.d. sample as a Double from an underlying distribution.
+ */
+ def nextValue(): Double
+
+ /**
+ * Returns a copy of the DistributionGenerator with a new instance of the rng object used in the
+ * class when applicable for non-locking concurrent usage.
+ */
+ def copy(): DistributionGenerator
+}
+
+/**
+ * :: Experimental ::
+ * Generates i.i.d. samples from U[0.0, 1.0]
+ */
+@Experimental
+class UniformGenerator extends DistributionGenerator {
+
+ // XORShiftRandom for better performance. Thread safety isn't necessary here.
+ private val random = new XORShiftRandom()
+
+ override def nextValue(): Double = {
+ random.nextDouble()
+ }
+
+ override def setSeed(seed: Long) = random.setSeed(seed)
+
+ override def copy(): UniformGenerator = new UniformGenerator()
+}
+
+/**
+ * :: Experimental ::
+ * Generates i.i.d. samples from the standard normal distribution.
+ */
+@Experimental
+class StandardNormalGenerator extends DistributionGenerator {
+
+ // XORShiftRandom for better performance. Thread safety isn't necessary here.
+ private val random = new XORShiftRandom()
+
+ override def nextValue(): Double = {
+ random.nextGaussian()
+ }
+
+ override def setSeed(seed: Long) = random.setSeed(seed)
+
+ override def copy(): StandardNormalGenerator = new StandardNormalGenerator()
+}
+
+/**
+ * :: Experimental ::
+ * Generates i.i.d. samples from the Poisson distribution with the given mean.
+ *
+ * @param mean mean for the Poisson distribution.
+ */
+@Experimental
+class PoissonGenerator(val mean: Double) extends DistributionGenerator {
+
+ private var rng = new Poisson(mean, new DRand)
+
+ override def nextValue(): Double = rng.nextDouble()
+
+ override def setSeed(seed: Long) {
+ rng = new Poisson(mean, new DRand(seed.toInt))
+ }
+
+ override def copy(): PoissonGenerator = new PoissonGenerator(mean)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala
new file mode 100644
index 0000000000000..d7ee2d3f46846
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala
@@ -0,0 +1,473 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.random
+
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.rdd.{RandomVectorRDD, RandomRDD}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+
+/**
+ * :: Experimental ::
+ * Generator methods for creating RDDs comprised of i.i.d samples from some distribution.
+ */
+@Experimental
+object RandomRDDGenerators {
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0].
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD.
+ * @param seed Seed for the RNG that generates the seed for the generator in each partition.
+ * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0].
+ */
+ @Experimental
+ def uniformRDD(sc: SparkContext, size: Long, numPartitions: Int, seed: Long): RDD[Double] = {
+ val uniform = new UniformGenerator()
+ randomRDD(sc, uniform, size, numPartitions, seed)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0].
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD.
+ * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0].
+ */
+ @Experimental
+ def uniformRDD(sc: SparkContext, size: Long, numPartitions: Int): RDD[Double] = {
+ uniformRDD(sc, size, numPartitions, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the uniform distribution on [0.0, 1.0].
+ * sc.defaultParallelism used for the number of partitions in the RDD.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param size Size of the RDD.
+ * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0].
+ */
+ @Experimental
+ def uniformRDD(sc: SparkContext, size: Long): RDD[Double] = {
+ uniformRDD(sc, size, sc.defaultParallelism, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the standard normal distribution.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD.
+ * @param seed Seed for the RNG that generates the seed for the generator in each partition.
+ * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0).
+ */
+ @Experimental
+ def normalRDD(sc: SparkContext, size: Long, numPartitions: Int, seed: Long): RDD[Double] = {
+ val normal = new StandardNormalGenerator()
+ randomRDD(sc, normal, size, numPartitions, seed)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the standard normal distribution.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD.
+ * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0).
+ */
+ @Experimental
+ def normalRDD(sc: SparkContext, size: Long, numPartitions: Int): RDD[Double] = {
+ normalRDD(sc, size, numPartitions, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the standard normal distribution.
+ * sc.defaultParallelism used for the number of partitions in the RDD.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param size Size of the RDD.
+ * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0).
+ */
+ @Experimental
+ def normalRDD(sc: SparkContext, size: Long): RDD[Double] = {
+ normalRDD(sc, size, sc.defaultParallelism, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param mean Mean, or lambda, for the Poisson distribution.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD.
+ * @param seed Seed for the RNG that generates the seed for the generator in each partition.
+ * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean).
+ */
+ @Experimental
+ def poissonRDD(sc: SparkContext,
+ mean: Double,
+ size: Long,
+ numPartitions: Int,
+ seed: Long): RDD[Double] = {
+ val poisson = new PoissonGenerator(mean)
+ randomRDD(sc, poisson, size, numPartitions, seed)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param mean Mean, or lambda, for the Poisson distribution.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD.
+ * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean).
+ */
+ @Experimental
+ def poissonRDD(sc: SparkContext, mean: Double, size: Long, numPartitions: Int): RDD[Double] = {
+ poissonRDD(sc, mean, size, numPartitions, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples from the Poisson distribution with the input mean.
+ * sc.defaultParallelism used for the number of partitions in the RDD.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param mean Mean, or lambda, for the Poisson distribution.
+ * @param size Size of the RDD.
+ * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean).
+ */
+ @Experimental
+ def poissonRDD(sc: SparkContext, mean: Double, size: Long): RDD[Double] = {
+ poissonRDD(sc, mean, size, sc.defaultParallelism, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param generator DistributionGenerator used to populate the RDD.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD.
+ * @param seed Seed for the RNG that generates the seed for the generator in each partition.
+ * @return RDD[Double] comprised of i.i.d. samples produced by generator.
+ */
+ @Experimental
+ def randomRDD(sc: SparkContext,
+ generator: DistributionGenerator,
+ size: Long,
+ numPartitions: Int,
+ seed: Long): RDD[Double] = {
+ new RandomRDD(sc, size, numPartitions, generator, seed)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param generator DistributionGenerator used to populate the RDD.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD.
+ * @return RDD[Double] comprised of i.i.d. samples produced by generator.
+ */
+ @Experimental
+ def randomRDD(sc: SparkContext,
+ generator: DistributionGenerator,
+ size: Long,
+ numPartitions: Int): RDD[Double] = {
+ randomRDD(sc, generator, size, numPartitions, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD comprised of i.i.d samples produced by the input DistributionGenerator.
+ * sc.defaultParallelism used for the number of partitions in the RDD.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param generator DistributionGenerator used to populate the RDD.
+ * @param size Size of the RDD.
+ * @return RDD[Double] comprised of i.i.d. samples produced by generator.
+ */
+ @Experimental
+ def randomRDD(sc: SparkContext,
+ generator: DistributionGenerator,
+ size: Long): RDD[Double] = {
+ randomRDD(sc, generator, size, sc.defaultParallelism, Utils.random.nextLong)
+ }
+
+ // TODO Generate RDD[Vector] from multivariate distributions.
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d samples drawn from the
+ * uniform distribution on [0.0, 1.0].
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @param numPartitions Number of partitions in the RDD.
+ * @param seed Seed for the RNG that generates the seed for the generator in each partition.
+ * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0].
+ */
+ @Experimental
+ def uniformVectorRDD(sc: SparkContext,
+ numRows: Long,
+ numCols: Int,
+ numPartitions: Int,
+ seed: Long): RDD[Vector] = {
+ val uniform = new UniformGenerator()
+ randomVectorRDD(sc, uniform, numRows, numCols, numPartitions, seed)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the
+ * uniform distribution on [0.0, 1.0].
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @param numPartitions Number of partitions in the RDD.
+ * @return RDD[Vector] with vectors containing i.i.d. samples ~ U[0.0, 1.0].
+ */
+ @Experimental
+ def uniformVectorRDD(sc: SparkContext,
+ numRows: Long,
+ numCols: Int,
+ numPartitions: Int): RDD[Vector] = {
+ uniformVectorRDD(sc, numRows, numCols, numPartitions, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the
+ * uniform distribution on [0.0, 1.0].
+ * sc.defaultParallelism used for the number of partitions in the RDD.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @return RDD[Vector] with vectors containing i.i.d. samples ~ U[0.0, 1.0].
+ */
+ @Experimental
+ def uniformVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = {
+ uniformVectorRDD(sc, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the
+ * standard normal distribution.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @param numPartitions Number of partitions in the RDD.
+ * @param seed Seed for the RNG that generates the seed for the generator in each partition.
+ * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0).
+ */
+ @Experimental
+ def normalVectorRDD(sc: SparkContext,
+ numRows: Long,
+ numCols: Int,
+ numPartitions: Int,
+ seed: Long): RDD[Vector] = {
+ val uniform = new StandardNormalGenerator()
+ randomVectorRDD(sc, uniform, numRows, numCols, numPartitions, seed)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the
+ * standard normal distribution.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @param numPartitions Number of partitions in the RDD.
+ * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0).
+ */
+ @Experimental
+ def normalVectorRDD(sc: SparkContext,
+ numRows: Long,
+ numCols: Int,
+ numPartitions: Int): RDD[Vector] = {
+ normalVectorRDD(sc, numRows, numCols, numPartitions, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the
+ * standard normal distribution.
+ * sc.defaultParallelism used for the number of partitions in the RDD.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0).
+ */
+ @Experimental
+ def normalVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = {
+ normalVectorRDD(sc, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the
+ * Poisson distribution with the input mean.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param mean Mean, or lambda, for the Poisson distribution.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @param numPartitions Number of partitions in the RDD.
+ * @param seed Seed for the RNG that generates the seed for the generator in each partition.
+ * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean).
+ */
+ @Experimental
+ def poissonVectorRDD(sc: SparkContext,
+ mean: Double,
+ numRows: Long,
+ numCols: Int,
+ numPartitions: Int,
+ seed: Long): RDD[Vector] = {
+ val poisson = new PoissonGenerator(mean)
+ randomVectorRDD(sc, poisson, numRows, numCols, numPartitions, seed)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the
+ * Poisson distribution with the input mean.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param mean Mean, or lambda, for the Poisson distribution.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @param numPartitions Number of partitions in the RDD.
+ * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean).
+ */
+ @Experimental
+ def poissonVectorRDD(sc: SparkContext,
+ mean: Double,
+ numRows: Long,
+ numCols: Int,
+ numPartitions: Int): RDD[Vector] = {
+ poissonVectorRDD(sc, mean, numRows, numCols, numPartitions, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the
+ * Poisson distribution with the input mean.
+ * sc.defaultParallelism used for the number of partitions in the RDD.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param mean Mean, or lambda, for the Poisson distribution.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean).
+ */
+ @Experimental
+ def poissonVectorRDD(sc: SparkContext,
+ mean: Double,
+ numRows: Long,
+ numCols: Int): RDD[Vector] = {
+ poissonVectorRDD(sc, mean, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the
+ * input DistributionGenerator.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param generator DistributionGenerator used to populate the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @param numPartitions Number of partitions in the RDD.
+ * @param seed Seed for the RNG that generates the seed for the generator in each partition.
+ * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator.
+ */
+ @Experimental
+ def randomVectorRDD(sc: SparkContext,
+ generator: DistributionGenerator,
+ numRows: Long,
+ numCols: Int,
+ numPartitions: Int,
+ seed: Long): RDD[Vector] = {
+ new RandomVectorRDD(sc, numRows, numCols, numPartitions, generator, seed)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the
+ * input DistributionGenerator.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param generator DistributionGenerator used to populate the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @param numPartitions Number of partitions in the RDD.
+ * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator.
+ */
+ @Experimental
+ def randomVectorRDD(sc: SparkContext,
+ generator: DistributionGenerator,
+ numRows: Long,
+ numCols: Int,
+ numPartitions: Int): RDD[Vector] = {
+ randomVectorRDD(sc, generator, numRows, numCols, numPartitions, Utils.random.nextLong)
+ }
+
+ /**
+ * :: Experimental ::
+ * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the
+ * input DistributionGenerator.
+ * sc.defaultParallelism used for the number of partitions in the RDD.
+ *
+ * @param sc SparkContext used to create the RDD.
+ * @param generator DistributionGenerator used to populate the RDD.
+ * @param numRows Number of Vectors in the RDD.
+ * @param numCols Number of elements in each Vector.
+ * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator.
+ */
+ @Experimental
+ def randomVectorRDD(sc: SparkContext,
+ generator: DistributionGenerator,
+ numRows: Long,
+ numCols: Int): RDD[Vector] = {
+ randomVectorRDD(sc, generator, numRows, numCols,
+ sc.defaultParallelism, Utils.random.nextLong)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala
new file mode 100644
index 0000000000000..f13282d07ff92
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.rdd
+
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.mllib.linalg.{DenseVector, Vector}
+import org.apache.spark.mllib.random.DistributionGenerator
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+
+import scala.util.Random
+
+private[mllib] class RandomRDDPartition(override val index: Int,
+ val size: Int,
+ val generator: DistributionGenerator,
+ val seed: Long) extends Partition {
+
+ require(size >= 0, "Non-negative partition size required.")
+}
+
+// These two classes are necessary since Range objects in Scala cannot have size > Int.MaxValue
+private[mllib] class RandomRDD(@transient sc: SparkContext,
+ size: Long,
+ numPartitions: Int,
+ @transient rng: DistributionGenerator,
+ @transient seed: Long = Utils.random.nextLong) extends RDD[Double](sc, Nil) {
+
+ require(size > 0, "Positive RDD size required.")
+ require(numPartitions > 0, "Positive number of partitions required")
+ require(math.ceil(size.toDouble / numPartitions) <= Int.MaxValue,
+ "Partition size cannot exceed Int.MaxValue")
+
+ override def compute(splitIn: Partition, context: TaskContext): Iterator[Double] = {
+ val split = splitIn.asInstanceOf[RandomRDDPartition]
+ RandomRDD.getPointIterator(split)
+ }
+
+ override def getPartitions: Array[Partition] = {
+ RandomRDD.getPartitions(size, numPartitions, rng, seed)
+ }
+}
+
+private[mllib] class RandomVectorRDD(@transient sc: SparkContext,
+ size: Long,
+ vectorSize: Int,
+ numPartitions: Int,
+ @transient rng: DistributionGenerator,
+ @transient seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, Nil) {
+
+ require(size > 0, "Positive RDD size required.")
+ require(numPartitions > 0, "Positive number of partitions required")
+ require(vectorSize > 0, "Positive vector size required.")
+ require(math.ceil(size.toDouble / numPartitions) <= Int.MaxValue,
+ "Partition size cannot exceed Int.MaxValue")
+
+ override def compute(splitIn: Partition, context: TaskContext): Iterator[Vector] = {
+ val split = splitIn.asInstanceOf[RandomRDDPartition]
+ RandomRDD.getVectorIterator(split, vectorSize)
+ }
+
+ override protected def getPartitions: Array[Partition] = {
+ RandomRDD.getPartitions(size, numPartitions, rng, seed)
+ }
+}
+
+private[mllib] object RandomRDD {
+
+ def getPartitions(size: Long,
+ numPartitions: Int,
+ rng: DistributionGenerator,
+ seed: Long): Array[Partition] = {
+
+ val partitions = new Array[RandomRDDPartition](numPartitions)
+ var i = 0
+ var start: Long = 0
+ var end: Long = 0
+ val random = new Random(seed)
+ while (i < numPartitions) {
+ end = ((i + 1) * size) / numPartitions
+ partitions(i) = new RandomRDDPartition(i, (end - start).toInt, rng, random.nextLong())
+ start = end
+ i += 1
+ }
+ partitions.asInstanceOf[Array[Partition]]
+ }
+
+ // The RNG has to be reset every time the iterator is requested to guarantee the same data
+ // every time the content of the RDD is examined.
+ def getPointIterator(partition: RandomRDDPartition): Iterator[Double] = {
+ val generator = partition.generator.copy()
+ generator.setSeed(partition.seed)
+ Array.fill(partition.size)(generator.nextValue()).toIterator
+ }
+
+ // The RNG has to be reset every time the iterator is requested to guarantee the same data
+ // every time the content of the RDD is examined.
+ def getVectorIterator(partition: RandomRDDPartition, vectorSize: Int): Iterator[Vector] = {
+ val generator = partition.generator.copy()
+ generator.setSeed(partition.seed)
+ Array.fill(partition.size)(new DenseVector(
+ (0 until vectorSize).map { _ => generator.nextValue() }.toArray)).toIterator
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index fe41863bce985..54854252d7477 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -56,9 +56,12 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
// A small optimization to avoid serializing the entire model. Only the weightsMatrix
// and intercept is needed.
val localWeights = weights
+ val bcWeights = testData.context.broadcast(localWeights)
val localIntercept = intercept
-
- testData.map(v => predictPoint(v, localWeights, localIntercept))
+ testData.mapPartitions { iter =>
+ val w = bcWeights.value
+ iter.map(v => predictPoint(v, w, localIntercept))
+ }
}
/**
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
index faa675b59cd50..862221d48798a 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
@@ -92,8 +92,6 @@ public void runLRUsingStaticMethods() {
testRDD.rdd(), 100, 1.0, 1.0);
int numAccurate = validatePrediction(validationData, model);
- System.out.println(numAccurate);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
}
-
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
index 642843f90204c..d94cfa2fcec81 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
@@ -57,4 +57,12 @@ class PythonMLLibAPISuite extends FunSuite {
assert(q.features === p.features)
}
}
+
+ test("double serialization") {
+ for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue)) {
+ val bytes = py.serializeDouble(x)
+ val deser = py.deserializeDouble(bytes)
+ assert(x === deser)
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 44b757b6a1fb7..3f6ff859374c7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.Matchers
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
object LogisticRegressionSuite {
@@ -126,3 +126,19 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}
+
+class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = LogisticRegressionWithSGD.train(points, 2)
+ val predictions = model.predict(points.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 516895d04222d..06cdd04f5fdae 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
object NaiveBayesSuite {
@@ -96,3 +96,21 @@ class NaiveBayesSuite extends FunSuite with LocalSparkContext {
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}
+
+class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 10
+ val n = 200000
+ val examples = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map { i =>
+ LabeledPoint(random.nextInt(2), Vectors.dense(Array.fill(n)(random.nextDouble())))
+ }
+ }
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = NaiveBayes.train(examples)
+ val predictions = model.predict(examples.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 886c71dde3af7..65e5df58db4c7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -17,17 +17,16 @@
package org.apache.spark.mllib.classification
-import scala.util.Random
import scala.collection.JavaConversions._
-
-import org.scalatest.FunSuite
+import scala.util.Random
import org.jblas.DoubleMatrix
+import org.scalatest.FunSuite
import org.apache.spark.SparkException
-import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
object SVMSuite {
@@ -193,3 +192,19 @@ class SVMSuite extends FunSuite with LocalSparkContext {
new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
}
}
+
+class SVMClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = SVMWithSGD.train(points, 2)
+ val predictions = model.predict(points.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 76a3bdf9b11c8..34bc4537a7b3a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -17,14 +17,16 @@
package org.apache.spark.mllib.clustering
+import scala.util.Random
+
import org.scalatest.FunSuite
-import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
class KMeansSuite extends FunSuite with LocalSparkContext {
- import KMeans.{RANDOM, K_MEANS_PARALLEL}
+ import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM}
test("single cluster") {
val data = sc.parallelize(Array(
@@ -38,26 +40,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
// No matter how many runs or iterations we use, we should get one cluster,
// centered at the mean of the points
- var model = KMeans.train(data, k=1, maxIterations=1)
+ var model = KMeans.train(data, k = 1, maxIterations = 1)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=2)
+ model = KMeans.train(data, k = 1, maxIterations = 2)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=5)
+ model = KMeans.train(data, k = 1, maxIterations = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
assert(model.clusterCenters.head === center)
model = KMeans.train(
- data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
+ data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL)
assert(model.clusterCenters.head === center)
}
@@ -100,26 +102,27 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
val center = Vectors.dense(1.0, 3.0, 4.0)
- var model = KMeans.train(data, k=1, maxIterations=1)
+ var model = KMeans.train(data, k = 1, maxIterations = 1)
assert(model.clusterCenters.size === 1)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=2)
+ model = KMeans.train(data, k = 1, maxIterations = 2)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=5)
+ model = KMeans.train(data, k = 1, maxIterations = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1,
+ initializationMode = K_MEANS_PARALLEL)
assert(model.clusterCenters.head === center)
}
@@ -145,25 +148,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
val center = Vectors.sparse(n, Seq((0, 1.0), (1, 3.0), (2, 4.0)))
- var model = KMeans.train(data, k=1, maxIterations=1)
+ var model = KMeans.train(data, k = 1, maxIterations = 1)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=2)
+ model = KMeans.train(data, k = 1, maxIterations = 2)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=5)
+ model = KMeans.train(data, k = 1, maxIterations = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
assert(model.clusterCenters.head === center)
- model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
+ model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1,
+ initializationMode = K_MEANS_PARALLEL)
assert(model.clusterCenters.head === center)
data.unpersist()
@@ -183,15 +187,15 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
// it will make at least five passes, and it will give non-zero probability to each
// unselected point as long as it hasn't yet selected all of them
- var model = KMeans.train(rdd, k=5, maxIterations=1)
+ var model = KMeans.train(rdd, k = 5, maxIterations = 1)
assert(Set(model.clusterCenters: _*) === Set(points: _*))
// Iterations of Lloyd's should not change the answer either
- model = KMeans.train(rdd, k=5, maxIterations=10)
+ model = KMeans.train(rdd, k = 5, maxIterations = 10)
assert(Set(model.clusterCenters: _*) === Set(points: _*))
// Neither should more runs
- model = KMeans.train(rdd, k=5, maxIterations=10, runs=5)
+ model = KMeans.train(rdd, k = 5, maxIterations = 10, runs = 5)
assert(Set(model.clusterCenters: _*) === Set(points: _*))
}
@@ -220,3 +224,22 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
}
}
}
+
+class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => Vectors.dense(Array.fill(n)(random.nextDouble)))
+ }.cache()
+ for (initMode <- Seq(KMeans.RANDOM, KMeans.K_MEANS_PARALLEL)) {
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = KMeans.train(points, 2, 2, 1, initMode)
+ val predictions = model.predict(points).collect()
+ val cost = model.computeCost(points)
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
index a961f89456a18..325b817980f68 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
@@ -17,12 +17,13 @@
package org.apache.spark.mllib.linalg.distributed
-import org.scalatest.FunSuite
+import scala.util.Random
import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd}
+import org.scalatest.FunSuite
-import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
class RowMatrixSuite extends FunSuite with LocalSparkContext {
@@ -193,3 +194,27 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext {
}
}
}
+
+class RowMatrixClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ var mat: RowMatrix = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ val m = 4
+ val n = 200000
+ val rows = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => Vectors.dense(Array.fill(n)(random.nextDouble())))
+ }
+ mat = new RowMatrix(rows)
+ }
+
+ test("task size should be small in svd") {
+ val svd = mat.computeSVD(1, computeU = true)
+ }
+
+ test("task size should be small in summarize") {
+ val summary = mat.computeColumnSummaryStatistics()
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index 951b4f7c6e6f4..dfb2eb7f0d14e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.optimization
-import scala.util.Random
import scala.collection.JavaConversions._
+import scala.util.Random
-import org.scalatest.FunSuite
-import org.scalatest.Matchers
+import org.scalatest.{FunSuite, Matchers}
-import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
object GradientDescentSuite {
@@ -46,7 +45,7 @@ object GradientDescentSuite {
val rnd = new Random(seed)
val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
- val unifRand = new scala.util.Random(45)
+ val unifRand = new Random(45)
val rLogis = (0 until nPoints).map { i =>
val u = unifRand.nextDouble()
math.log(u) - math.log(1.0-u)
@@ -144,3 +143,26 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers
"should be initialWeightsWithIntercept.")
}
}
+
+class GradientDescentClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => (1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val (weights, loss) = GradientDescent.runMiniBatchSGD(
+ points,
+ new LogisticGradient,
+ new SquaredL2Updater,
+ 0.1,
+ 2,
+ 1.0,
+ 1.0,
+ Vectors.dense(new Array[Double](n)))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index fe7a9033cd5f4..ff414742e8393 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -17,12 +17,13 @@
package org.apache.spark.mllib.optimization
-import org.scalatest.FunSuite
-import org.scalatest.Matchers
+import scala.util.Random
+
+import org.scalatest.{FunSuite, Matchers}
-import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
@@ -230,3 +231,24 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
"The weight differences between LBFGS and GD should be within 2%.")
}
}
+
+class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small") {
+ val m = 10
+ val n = 200000
+ val examples = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => (1.0, Vectors.dense(Array.fill(n)(random.nextDouble))))
+ }.cache()
+ val lbfgs = new LBFGS(new LogisticGradient, new SquaredL2Updater)
+ .setNumCorrections(1)
+ .setConvergenceTol(1e-12)
+ .setMaxNumIterations(1)
+ .setRegParam(1.0)
+ val random = new Random(0)
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val weights = lbfgs.optimize(examples, Vectors.dense(Array.fill(n)(random.nextDouble)))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala
new file mode 100644
index 0000000000000..974dec4c0b5ee
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.random
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.util.StatCounter
+
+// TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
+class DistributionGeneratorSuite extends FunSuite {
+
+ def apiChecks(gen: DistributionGenerator) {
+
+ // resetting seed should generate the same sequence of random numbers
+ gen.setSeed(42L)
+ val array1 = (0 until 1000).map(_ => gen.nextValue())
+ gen.setSeed(42L)
+ val array2 = (0 until 1000).map(_ => gen.nextValue())
+ assert(array1.equals(array2))
+
+ // newInstance should contain a different instance of the rng
+ // i.e. setting different seeds for different instances produces different sequences of
+ // random numbers.
+ val gen2 = gen.copy()
+ gen.setSeed(0L)
+ val array3 = (0 until 1000).map(_ => gen.nextValue())
+ gen2.setSeed(1L)
+ val array4 = (0 until 1000).map(_ => gen2.nextValue())
+ // Compare arrays instead of elements since individual elements can coincide by chance but the
+ // sequences should differ given two different seeds.
+ assert(!array3.equals(array4))
+
+ // test that setting the same seed in the copied instance produces the same sequence of numbers
+ gen.setSeed(0L)
+ val array5 = (0 until 1000).map(_ => gen.nextValue())
+ gen2.setSeed(0L)
+ val array6 = (0 until 1000).map(_ => gen2.nextValue())
+ assert(array5.equals(array6))
+ }
+
+ def distributionChecks(gen: DistributionGenerator,
+ mean: Double = 0.0,
+ stddev: Double = 1.0,
+ epsilon: Double = 0.01) {
+ for (seed <- 0 until 5) {
+ gen.setSeed(seed.toLong)
+ val sample = (0 until 100000).map { _ => gen.nextValue()}
+ val stats = new StatCounter(sample)
+ assert(math.abs(stats.mean - mean) < epsilon)
+ assert(math.abs(stats.stdev - stddev) < epsilon)
+ }
+ }
+
+ test("UniformGenerator") {
+ val uniform = new UniformGenerator()
+ apiChecks(uniform)
+ // Stddev of uniform distribution = (ub - lb) / math.sqrt(12)
+ distributionChecks(uniform, 0.5, 1 / math.sqrt(12))
+ }
+
+ test("StandardNormalGenerator") {
+ val normal = new StandardNormalGenerator()
+ apiChecks(normal)
+ distributionChecks(normal, 0.0, 1.0)
+ }
+
+ test("PoissonGenerator") {
+ // mean = 0.0 will not pass the API checks since 0.0 is always deterministically produced.
+ for (mean <- List(1.0, 5.0, 100.0)) {
+ val poisson = new PoissonGenerator(mean)
+ apiChecks(poisson)
+ distributionChecks(poisson, mean, math.sqrt(mean), 0.1)
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala
new file mode 100644
index 0000000000000..6aa4f803df0f7
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.random
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD}
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.StatCounter
+
+/*
+ * Note: avoid including APIs that do not set the seed for the RNG in unit tests
+ * in order to guarantee deterministic behavior.
+ *
+ * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
+ */
+class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Serializable {
+
+ def testGeneratedRDD(rdd: RDD[Double],
+ expectedSize: Long,
+ expectedNumPartitions: Int,
+ expectedMean: Double,
+ expectedStddev: Double,
+ epsilon: Double = 0.01) {
+ val stats = rdd.stats()
+ assert(expectedSize === stats.count)
+ assert(expectedNumPartitions === rdd.partitions.size)
+ assert(math.abs(stats.mean - expectedMean) < epsilon)
+ assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+ }
+
+ // assume test RDDs are small
+ def testGeneratedVectorRDD(rdd: RDD[Vector],
+ expectedRows: Long,
+ expectedColumns: Int,
+ expectedNumPartitions: Int,
+ expectedMean: Double,
+ expectedStddev: Double,
+ epsilon: Double = 0.01) {
+ assert(expectedNumPartitions === rdd.partitions.size)
+ val values = new ArrayBuffer[Double]()
+ rdd.collect.foreach { vector => {
+ assert(vector.size === expectedColumns)
+ values ++= vector.toArray
+ }}
+ assert(expectedRows === values.size / expectedColumns)
+ val stats = new StatCounter(values)
+ assert(math.abs(stats.mean - expectedMean) < epsilon)
+ assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+ }
+
+ test("RandomRDD sizes") {
+
+ // some cases where size % numParts != 0 to test getPartitions behaves correctly
+ for ((size, numPartitions) <- List((10000, 6), (12345, 1), (1000, 101))) {
+ val rdd = new RandomRDD(sc, size, numPartitions, new UniformGenerator, 0L)
+ assert(rdd.count() === size)
+ assert(rdd.partitions.size === numPartitions)
+
+ // check that partition sizes are balanced
+ val partSizes = rdd.partitions.map(p => p.asInstanceOf[RandomRDDPartition].size.toDouble)
+ val partStats = new StatCounter(partSizes)
+ assert(partStats.max - partStats.min <= 1)
+ }
+
+ // size > Int.MaxValue
+ val size = Int.MaxValue.toLong * 100L
+ val numPartitions = 101
+ val rdd = new RandomRDD(sc, size, numPartitions, new UniformGenerator, 0L)
+ assert(rdd.partitions.size === numPartitions)
+ val count = rdd.partitions.foldLeft(0L) { (count, part) =>
+ count + part.asInstanceOf[RandomRDDPartition].size
+ }
+ assert(count === size)
+
+ // size needs to be positive
+ intercept[IllegalArgumentException] { new RandomRDD(sc, 0, 10, new UniformGenerator, 0L) }
+
+ // numPartitions needs to be positive
+ intercept[IllegalArgumentException] { new RandomRDD(sc, 100, 0, new UniformGenerator, 0L) }
+
+ // partition size needs to be <= Int.MaxValue
+ intercept[IllegalArgumentException] {
+ new RandomRDD(sc, Int.MaxValue.toLong * 100L, 99, new UniformGenerator, 0L)
+ }
+ }
+
+ test("randomRDD for different distributions") {
+ val size = 100000L
+ val numPartitions = 10
+ val poissonMean = 100.0
+
+ for (seed <- 0 until 5) {
+ val uniform = RandomRDDGenerators.uniformRDD(sc, size, numPartitions, seed)
+ testGeneratedRDD(uniform, size, numPartitions, 0.5, 1 / math.sqrt(12))
+
+ val normal = RandomRDDGenerators.normalRDD(sc, size, numPartitions, seed)
+ testGeneratedRDD(normal, size, numPartitions, 0.0, 1.0)
+
+ val poisson = RandomRDDGenerators.poissonRDD(sc, poissonMean, size, numPartitions, seed)
+ testGeneratedRDD(poisson, size, numPartitions, poissonMean, math.sqrt(poissonMean), 0.1)
+ }
+
+ // mock distribution to check that partitions have unique seeds
+ val random = RandomRDDGenerators.randomRDD(sc, new MockDistro(), 1000L, 1000, 0L)
+ assert(random.collect.size === random.collect.distinct.size)
+ }
+
+ test("randomVectorRDD for different distributions") {
+ val rows = 1000L
+ val cols = 100
+ val parts = 10
+ val poissonMean = 100.0
+
+ for (seed <- 0 until 5) {
+ val uniform = RandomRDDGenerators.uniformVectorRDD(sc, rows, cols, parts, seed)
+ testGeneratedVectorRDD(uniform, rows, cols, parts, 0.5, 1 / math.sqrt(12))
+
+ val normal = RandomRDDGenerators.normalVectorRDD(sc, rows, cols, parts, seed)
+ testGeneratedVectorRDD(normal, rows, cols, parts, 0.0, 1.0)
+
+ val poisson = RandomRDDGenerators.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed)
+ testGeneratedVectorRDD(poisson, rows, cols, parts, poissonMean, math.sqrt(poissonMean), 0.1)
+ }
+ }
+}
+
+private[random] class MockDistro extends DistributionGenerator {
+
+ var seed = 0L
+
+ // This allows us to check that each partition has a different seed
+ override def nextValue(): Double = seed.toDouble
+
+ override def setSeed(seed: Long) = this.seed = seed
+
+ override def copy(): MockDistro = new MockDistro
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index bfa42959c8ead..7aa96421aed87 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -17,10 +17,13 @@
package org.apache.spark.mllib.regression
+import scala.util.Random
+
import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
+ LocalSparkContext}
class LassoSuite extends FunSuite with LocalSparkContext {
@@ -113,3 +116,19 @@ class LassoSuite extends FunSuite with LocalSparkContext {
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}
+
+class LassoClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = LassoWithSGD.train(points, 2)
+ val predictions = model.predict(points.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 7aaad7d7a3e39..4f89112b650c5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -17,10 +17,13 @@
package org.apache.spark.mllib.regression
+import scala.util.Random
+
import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
+ LocalSparkContext}
class LinearRegressionSuite extends FunSuite with LocalSparkContext {
@@ -122,3 +125,19 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
sparseValidationData.map(row => model.predict(row.features)), sparseValidationData)
}
}
+
+class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = LinearRegressionWithSGD.train(points, 2)
+ val predictions = model.predict(points.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index 67768e17fbe6d..727bbd051ff15 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -17,11 +17,14 @@
package org.apache.spark.mllib.regression
-import org.scalatest.FunSuite
+import scala.util.Random
import org.jblas.DoubleMatrix
+import org.scalatest.FunSuite
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
+ LocalSparkContext}
class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
@@ -73,3 +76,19 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
"ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
}
}
+
+class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+ test("task size should be small in both training and prediction") {
+ val m = 4
+ val n = 200000
+ val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+ val random = new Random(idx)
+ iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+ }.cache()
+ // If we serialize data directly in the task closure, the size of the serialized task would be
+ // greater than 1MB and hence Spark would throw an error.
+ val model = RidgeRegressionWithSGD.train(points, 2)
+ val predictions = model.predict(points.map(_.features))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
new file mode 100644
index 0000000000000..5e9101cdd3804
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import org.scalatest.{Suite, BeforeAndAfterAll}
+
+import org.apache.spark.{SparkConf, SparkContext}
+
+trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite =>
+ @transient var sc: SparkContext = _
+
+ override def beforeAll() {
+ val conf = new SparkConf()
+ .setMaster("local-cluster[2, 1, 512]")
+ .setAppName("test-cluster")
+ .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data
+ sc = new SparkContext(conf)
+ super.beforeAll()
+ }
+
+ override def afterAll() {
+ if (sc != null) {
+ sc.stop()
+ }
+ super.afterAll()
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
index 0d4868f3d9e42..7857d9e5ee5c4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
@@ -20,13 +20,16 @@ package org.apache.spark.mllib.util
import org.scalatest.Suite
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkConf, SparkContext}
trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
@transient var sc: SparkContext = _
override def beforeAll() {
- sc = new SparkContext("local", "test")
+ val conf = new SparkConf()
+ .setMaster("local")
+ .setAppName("test")
+ sc = new SparkContext(conf)
super.beforeAll()
}
diff --git a/pom.xml b/pom.xml
index 96a0c60d24de6..20cc248772aac 100644
--- a/pom.xml
+++ b/pom.xml
@@ -95,7 +95,6 @@
sql/catalyst
sql/core
sql/hive
- sql/hive-thriftserver
repl
assembly
external/twitter
@@ -254,9 +253,9 @@
3.3.2
- commons-codec
- commons-codec
- 1.5
+ commons-codec
+ commons-codec
+ 1.5
com.google.code.findbugs
@@ -959,6 +958,30 @@
org.apache.maven.plugins
maven-source-plugin
+
+ org.scalastyle
+ scalastyle-maven-plugin
+ 0.4.0
+
+ false
+ true
+ false
+ false
+ ${basedir}/src/main/scala
+ ${basedir}/src/test/scala
+ scalastyle-config.xml
+ scalastyle-output.xml
+ UTF-8
+
+
+
+ package
+
+ check
+
+
+
+
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index e9220db6b1f9a..5ff88f0dd1cac 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -31,7 +31,6 @@ import com.typesafe.tools.mima.core._
* MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap")
*/
object MimaExcludes {
-
def excludes(version: String) =
version match {
case v if v.startsWith("1.1") =>
@@ -62,6 +61,15 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.storage.MemoryStore.Entry")
) ++
+ Seq(
+ // Renamed putValues -> putArray + putIterator
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.storage.MemoryStore.putValues"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.storage.DiskStore.putValues"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.storage.TachyonStore.putValues")
+ ) ++
Seq(
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.FlumeReceiver.this")
) ++
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 21dcb40e9a22f..ae985fe549c85 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -30,13 +30,12 @@ object BuildCommons {
private val buildLocation = file(".").getAbsoluteFile.getParentFile
- val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, spark, sql,
- streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt, streamingTwitter,
+ val allProjects@Seq(bagel, catalyst, core, graphx, hive, mllib, repl, spark, sql, streaming,
+ streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt, streamingTwitter,
streamingZeromq) =
- Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl",
- "spark", "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka",
- "streaming-mqtt", "streaming-twitter", "streaming-zeromq").
- map(ProjectRef(buildLocation, _))
+ Seq("bagel", "catalyst", "core", "graphx", "hive", "mllib", "repl", "spark", "sql",
+ "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt",
+ "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _))
val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl) =
Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl")
@@ -102,7 +101,7 @@ object SparkBuild extends PomBuild {
Properties.envOrNone("SBT_MAVEN_PROPERTIES") match {
case Some(v) =>
v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.split("=")).foreach(x => System.setProperty(x(0), x(1)))
- case _ =>
+ case _ =>
}
override val userPropertiesMap = System.getProperties.toMap
@@ -160,7 +159,7 @@ object SparkBuild extends PomBuild {
/* Enable Mima for all projects except spark, hive, catalyst, sql and repl */
// TODO: Add Sql to mima checks
- allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl).contains(x)).
+ allProjects.filterNot(y => Seq(spark, sql, hive, catalyst, repl).exists(x => x == y)).
foreach (x => enable(MimaBuild.mimaSettings(sparkHome, x))(x))
/* Enable Assembly for all assembly projects */
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 024fb881877c9..e8ac9895cf54a 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -37,6 +37,15 @@
from py4j.java_collections import ListConverter
+# These are special default configs for PySpark, they will overwrite
+# the default ones for Spark if they are not configured by user.
+DEFAULT_CONFIGS = {
+ "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
+ "spark.serializer.objectStreamReset": 100,
+ "spark.rdd.compress": True,
+}
+
+
class SparkContext(object):
"""
Main entry point for Spark functionality. A SparkContext represents the
@@ -101,7 +110,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
else:
self.serializer = BatchedSerializer(self._unbatched_serializer,
batchSize)
- self._conf.setIfMissing("spark.rdd.compress", "true")
+
# Set any parameters passed directly to us on the conf
if master:
self._conf.setMaster(master)
@@ -112,6 +121,8 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
if environment:
for key, value in environment.iteritems():
self._conf.setExecutorEnv(key, value)
+ for key, value in DEFAULT_CONFIGS.items():
+ self._conf.setIfMissing(key, value)
# Check that we have at least the required parameters
if not self._conf.contains("spark.master"):
@@ -216,6 +227,13 @@ def setSystemProperty(cls, key, value):
SparkContext._ensure_initialized()
SparkContext._jvm.java.lang.System.setProperty(key, value)
+ @property
+ def version(self):
+ """
+ The version of Spark on which this application is running.
+ """
+ return self._jsc.version()
+
@property
def defaultParallelism(self):
"""
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index 43b491a9716fc..8e3ad6b783b6c 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -72,9 +72,9 @@
# Python interpreter must agree on what endian the machine is.
-DENSE_VECTOR_MAGIC = 1
+DENSE_VECTOR_MAGIC = 1
SPARSE_VECTOR_MAGIC = 2
-DENSE_MATRIX_MAGIC = 3
+DENSE_MATRIX_MAGIC = 3
LABELED_POINT_MAGIC = 4
@@ -97,8 +97,28 @@ def _deserialize_numpy_array(shape, ba, offset, dtype=float64):
return ar.copy()
+def _serialize_double(d):
+ """
+ Serialize a double (float or numpy.float64) into a mutually understood format.
+ """
+ if type(d) == float or type(d) == float64:
+ d = float64(d)
+ ba = bytearray(8)
+ _copyto(d, buffer=ba, offset=0, shape=[1], dtype=float64)
+ return ba
+ else:
+ raise TypeError("_serialize_double called on non-float input")
+
+
def _serialize_double_vector(v):
- """Serialize a double vector into a mutually understood format.
+ """
+ Serialize a double vector into a mutually understood format.
+
+ Note: we currently do not use a magic byte for double for storage
+ efficiency. This should be reconsidered when we add Ser/De for other
+ 8-byte types (e.g. Long), for safety. The corresponding deserializer,
+ _deserialize_double, needs to be modified as well if the serialization
+ scheme changes.
>>> x = array([1,2,3])
>>> y = _deserialize_double_vector(_serialize_double_vector(x))
@@ -148,6 +168,28 @@ def _serialize_sparse_vector(v):
return ba
+def _deserialize_double(ba, offset=0):
+ """Deserialize a double from a mutually understood format.
+
+ >>> import sys
+ >>> _deserialize_double(_serialize_double(123.0)) == 123.0
+ True
+ >>> _deserialize_double(_serialize_double(float64(0.0))) == 0.0
+ True
+ >>> x = sys.float_info.max
+ >>> _deserialize_double(_serialize_double(sys.float_info.max)) == x
+ True
+ >>> y = float64(sys.float_info.max)
+ >>> _deserialize_double(_serialize_double(sys.float_info.max)) == y
+ True
+ """
+ if type(ba) != bytearray:
+ raise TypeError("_deserialize_double called on a %s; wanted bytearray" % type(ba))
+ if len(ba) - offset != 8:
+ raise TypeError("_deserialize_double called on a %d-byte array; wanted 8 bytes." % (len(ba) - offset))
+ return struct.unpack("d", ba[offset:])[0]
+
+
def _deserialize_double_vector(ba, offset=0):
"""Deserialize a double vector from a mutually understood format.
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 113a082e16721..b84d976114f0d 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1687,7 +1687,6 @@ def _jrdd(self):
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
self.ctx._pickled_broadcast_vars.clear()
- class_tag = self._prev_jrdd.classTag()
env = MapConverter().convert(self.ctx.environment,
self.ctx._gateway._gateway_client)
includes = ListConverter().convert(self.ctx._python_includes,
@@ -1696,8 +1695,7 @@ def _jrdd(self):
bytearray(pickled_command),
env, includes, self.preservesPartitioning,
self.ctx.pythonExec,
- broadcast_vars, self.ctx._javaAccumulator,
- class_tag)
+ broadcast_vars, self.ctx._javaAccumulator)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index a92abbf371f18..8ba51461d106d 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -226,6 +226,15 @@ def test_transforming_cartesian_result(self):
cart = rdd1.cartesian(rdd2)
result = cart.map(lambda (x, y): x + y).collect()
+ def test_transforming_pickle_file(self):
+ # Regression test for SPARK-2601
+ data = self.sc.parallelize(["Hello", "World!"])
+ tempFile = tempfile.NamedTemporaryFile(delete=True)
+ tempFile.close()
+ data.saveAsPickleFile(tempFile.name)
+ pickled_file = self.sc.pickleFile(tempFile.name)
+ pickled_file.map(lambda x: x).collect()
+
def test_cartesian_on_textfile(self):
# Regression test for
path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh
deleted file mode 100755
index 8398e6f19b511..0000000000000
--- a/sbin/start-thriftserver.sh
+++ /dev/null
@@ -1,36 +0,0 @@
-#!/usr/bin/env bash
-
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-#
-# Shell script for starting the Spark SQL Thrift server
-
-# Enter posix mode for bash
-set -o posix
-
-# Figure out where Spark is installed
-FWDIR="$(cd `dirname $0`/..; pwd)"
-
-if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
- echo "Usage: ./sbin/start-thriftserver [options]"
- $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
- exit 0
-fi
-
-CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2"
-exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 531bfddbf237b..6decde3fcd62d 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -32,7 +32,7 @@
Spark Project Catalyst
http://spark.apache.org/
- catalyst
+ catalyst
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 67a8ce9b88c3f..47c7ad076ad07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -50,6 +50,7 @@ trait HiveTypeCoercion {
StringToIntegralCasts ::
FunctionArgumentConversion ::
CastNulls ::
+ Division ::
Nil
/**
@@ -317,6 +318,23 @@ trait HiveTypeCoercion {
}
}
+ /**
+ * Hive only performs integral division with the DIV operator. The arguments to / are always
+ * converted to fractional types.
+ */
+ object Division extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ // Skip nodes who's children have not been resolved yet.
+ case e if !e.childrenResolved => e
+
+ // Decimal and Double remain the same
+ case d: Divide if d.dataType == DoubleType => d
+ case d: Divide if d.dataType == DecimalType => d
+
+ case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
+ }
+ }
+
/**
* Ensures that NullType gets casted to some other types under certain circumstances.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
index a357c6ffb8977..1d5f033f0d274 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
@@ -43,7 +43,8 @@ case class NativeCommand(cmd: String) extends Command {
*/
case class SetCommand(key: Option[String], value: Option[String]) extends Command {
override def output = Seq(
- BoundReference(1, AttributeReference("", StringType, nullable = false)()))
+ BoundReference(0, AttributeReference("key", StringType, nullable = false)()),
+ BoundReference(1, AttributeReference("value", StringType, nullable = false)()))
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index d607eed1bea89..0a27cce337482 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -83,7 +83,7 @@ class ConstantFoldingSuite extends PlanTest {
Literal(10) as Symbol("2*3+4"),
Literal(14) as Symbol("2*(3+4)"))
.where(Literal(true))
- .groupBy(Literal(3))(Literal(3) as Symbol("9/3"))
+ .groupBy(Literal(3.0))(Literal(3.0) as Symbol("9/3"))
.analyze
comparePlans(optimized, correctAnswer)
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 3a038a2db6173..c309c43804d97 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -32,7 +32,7 @@
Spark Project SQL
http://spark.apache.org/
- sql
+ sql
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 41920c00b5a2c..2b787e14f3f15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -30,13 +30,12 @@ import scala.collection.JavaConverters._
* SQLConf is thread-safe (internally synchronized so safe to be used in multiple threads).
*/
trait SQLConf {
- import SQLConf._
/** ************************ Spark SQL Params/Hints ******************* */
// TODO: refactor so that these hints accessors don't pollute the name space of SQLContext?
/** Number of partitions to use for shuffle operators. */
- private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt
+ private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt
/**
* Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
@@ -44,10 +43,11 @@ trait SQLConf {
* effectively disables auto conversion.
* Hive setting: hive.auto.convert.join.noconditionaltask.size.
*/
- private[spark] def autoConvertJoinSize: Int = get(AUTO_CONVERT_JOIN_SIZE, "10000").toInt
+ private[spark] def autoConvertJoinSize: Int =
+ get("spark.sql.auto.convert.join.size", "10000").toInt
/** A comma-separated list of table names marked to be broadcasted during joins. */
- private[spark] def joinBroadcastTables: String = get(JOIN_BROADCAST_TABLES, "")
+ private[spark] def joinBroadcastTables: String = get("spark.sql.join.broadcastTables", "")
/** ********************** SQLConf functionality methods ************ */
@@ -61,7 +61,7 @@ trait SQLConf {
def set(key: String, value: String): Unit = {
require(key != null, "key cannot be null")
- require(value != null, s"value cannot be null for $key")
+ require(value != null, s"value cannot be null for ${key}")
settings.put(key, value)
}
@@ -90,13 +90,3 @@ trait SQLConf {
}
}
-
-object SQLConf {
- val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size"
- val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
- val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables"
-
- object Deprecated {
- val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 9293239131d52..98d2f89c8ae71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -17,13 +17,12 @@
package org.apache.spark.sql.execution
-import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.{Row, SQLConf, SQLContext}
+import org.apache.spark.sql.{Row, SQLContext}
trait Command {
/**
@@ -45,53 +44,28 @@ trait Command {
case class SetCommand(
key: Option[String], value: Option[String], output: Seq[Attribute])(
@transient context: SQLContext)
- extends LeafNode with Command with Logging {
+ extends LeafNode with Command {
- override protected[sql] lazy val sideEffectResult: Seq[String] = (key, value) match {
+ override protected[sql] lazy val sideEffectResult: Seq[(String, String)] = (key, value) match {
// Set value for key k.
case (Some(k), Some(v)) =>
- if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
- logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
- s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.")
- context.set(SQLConf.SHUFFLE_PARTITIONS, v)
- Array(s"${SQLConf.SHUFFLE_PARTITIONS}=$v")
- } else {
- context.set(k, v)
- Array(s"$k=$v")
- }
+ context.set(k, v)
+ Array(k -> v)
// Query the value bound to key k.
case (Some(k), _) =>
- // TODO (lian) This is just a workaround to make the Simba ODBC driver work.
- // Should remove this once we get the ODBC driver updated.
- if (k == "-v") {
- val hiveJars = Seq(
- "hive-exec-0.12.0.jar",
- "hive-service-0.12.0.jar",
- "hive-common-0.12.0.jar",
- "hive-hwi-0.12.0.jar",
- "hive-0.12.0.jar").mkString(":")
-
- Array(
- "system:java.class.path=" + hiveJars,
- "system:sun.java.command=shark.SharkServer2")
- }
- else {
- Array(s"$k=${context.getOption(k).getOrElse("")}")
- }
+ Array(k -> context.getOption(k).getOrElse(""))
// Query all key-value pairs that are set in the SQLConf of the context.
case (None, None) =>
- context.getAll.map { case (k, v) =>
- s"$k=$v"
- }
+ context.getAll
case _ =>
throw new IllegalArgumentException()
}
def execute(): RDD[Row] = {
- val rows = sideEffectResult.map { line => new GenericRow(Array[Any](line)) }
+ val rows = sideEffectResult.map { case (k, v) => new GenericRow(Array[Any](k, v)) }
context.sparkContext.parallelize(rows, 1)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
index 1a58d73d9e7f4..08293f7f0ca30 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -54,10 +54,10 @@ class SQLConfSuite extends QueryTest {
assert(get(testKey, testVal + "_") == testVal)
assert(TestSQLContext.get(testKey, testVal + "_") == testVal)
- sql("set some.property=20")
- assert(get("some.property", "0") == "20")
- sql("set some.property = 40")
- assert(get("some.property", "0") == "40")
+ sql("set mapred.reduce.tasks=20")
+ assert(get("mapred.reduce.tasks", "0") == "20")
+ sql("set mapred.reduce.tasks = 40")
+ assert(get("mapred.reduce.tasks", "0") == "40")
val key = "spark.sql.key"
val vs = "val0,val_1,val2.3,my_table"
@@ -70,9 +70,4 @@ class SQLConfSuite extends QueryTest {
clear()
}
- test("deprecated property") {
- clear()
- sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
- assert(get(SQLConf.SHUFFLE_PARTITIONS) == "10")
- }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index de9e8aa4f62ed..6736189c96d4b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -424,25 +424,25 @@ class SQLQuerySuite extends QueryTest {
sql(s"SET $testKey=$testVal")
checkAnswer(
sql("SET"),
- Seq(Seq(s"$testKey=$testVal"))
+ Seq(Seq(testKey, testVal))
)
sql(s"SET ${testKey + testKey}=${testVal + testVal}")
checkAnswer(
sql("set"),
Seq(
- Seq(s"$testKey=$testVal"),
- Seq(s"${testKey + testKey}=${testVal + testVal}"))
+ Seq(testKey, testVal),
+ Seq(testKey + testKey, testVal + testVal))
)
// "set key"
checkAnswer(
sql(s"SET $testKey"),
- Seq(Seq(s"$testKey=$testVal"))
+ Seq(Seq(testKey, testVal))
)
checkAnswer(
sql(s"SET $nonexistentKey"),
- Seq(Seq(s"$nonexistentKey="))
+ Seq(Seq(nonexistentKey, ""))
)
clear()
}
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
deleted file mode 100644
index 7fac90fdc596d..0000000000000
--- a/sql/hive-thriftserver/pom.xml
+++ /dev/null
@@ -1,82 +0,0 @@
-
-
-
-
- 4.0.0
-
- org.apache.spark
- spark-parent
- 1.1.0-SNAPSHOT
- ../../pom.xml
-
-
- org.apache.spark
- spark-hive-thriftserver_2.10
- jar
- Spark Project Hive
- http://spark.apache.org/
-
- hive-thriftserver
-
-
-
-
- org.apache.spark
- spark-hive_${scala.binary.version}
- ${project.version}
-
-
- org.spark-project.hive
- hive-cli
- ${hive.version}
-
-
- org.spark-project.hive
- hive-jdbc
- ${hive.version}
-
-
- org.spark-project.hive
- hive-beeline
- ${hive.version}
-
-
- org.scalatest
- scalatest_${scala.binary.version}
- test
-
-
-
- target/scala-${scala.binary.version}/classes
- target/scala-${scala.binary.version}/test-classes
-
-
- org.scalatest
- scalatest-maven-plugin
-
-
- org.apache.maven.plugins
- maven-deploy-plugin
-
- true
-
-
-
-
-
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
deleted file mode 100644
index ddbc2a79fb512..0000000000000
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
+++ /dev/null
@@ -1,97 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.thriftserver
-
-import scala.collection.JavaConversions._
-
-import org.apache.commons.logging.LogFactory
-import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.ql.session.SessionState
-import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService
-import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor}
-
-import org.apache.spark.sql.Logging
-import org.apache.spark.sql.hive.HiveContext
-import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
-
-/**
- * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a
- * `HiveThriftServer2` thrift server.
- */
-private[hive] object HiveThriftServer2 extends Logging {
- var LOG = LogFactory.getLog(classOf[HiveServer2])
-
- def main(args: Array[String]) {
- val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2")
-
- if (!optionsProcessor.process(args)) {
- logger.warn("Error starting HiveThriftServer2 with given arguments")
- System.exit(-1)
- }
-
- val ss = new SessionState(new HiveConf(classOf[SessionState]))
-
- // Set all properties specified via command line.
- val hiveConf: HiveConf = ss.getConf
- hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) =>
- logger.debug(s"HiveConf var: $k=$v")
- }
-
- SessionState.start(ss)
-
- logger.info("Starting SparkContext")
- SparkSQLEnv.init()
- SessionState.start(ss)
-
- Runtime.getRuntime.addShutdownHook(
- new Thread() {
- override def run() {
- SparkSQLEnv.sparkContext.stop()
- }
- }
- )
-
- try {
- val server = new HiveThriftServer2(SparkSQLEnv.hiveContext)
- server.init(hiveConf)
- server.start()
- logger.info("HiveThriftServer2 started")
- } catch {
- case e: Exception =>
- logger.error("Error starting HiveThriftServer2", e)
- System.exit(-1)
- }
- }
-}
-
-private[hive] class HiveThriftServer2(hiveContext: HiveContext)
- extends HiveServer2
- with ReflectedCompositeService {
-
- override def init(hiveConf: HiveConf) {
- val sparkSqlCliService = new SparkSQLCLIService(hiveContext)
- setSuperField(this, "cliService", sparkSqlCliService)
- addService(sparkSqlCliService)
-
- val thriftCliService = new ThriftBinaryCLIService(sparkSqlCliService)
- setSuperField(this, "thriftCLIService", thriftCliService)
- addService(thriftCliService)
-
- initCompositeService(hiveConf)
- }
-}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala
deleted file mode 100644
index 599294dfbb7d7..0000000000000
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ReflectionUtils.scala
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.thriftserver
-
-private[hive] object ReflectionUtils {
- def setSuperField(obj : Object, fieldName: String, fieldValue: Object) {
- setAncestorField(obj, 1, fieldName, fieldValue)
- }
-
- def setAncestorField(obj: AnyRef, level: Int, fieldName: String, fieldValue: AnyRef) {
- val ancestor = Iterator.iterate[Class[_]](obj.getClass)(_.getSuperclass).drop(level).next()
- val field = ancestor.getDeclaredField(fieldName)
- field.setAccessible(true)
- field.set(obj, fieldValue)
- }
-
- def getSuperField[T](obj: AnyRef, fieldName: String): T = {
- getAncestorField[T](obj, 1, fieldName)
- }
-
- def getAncestorField[T](clazz: Object, level: Int, fieldName: String): T = {
- val ancestor = Iterator.iterate[Class[_]](clazz.getClass)(_.getSuperclass).drop(level).next()
- val field = ancestor.getDeclaredField(fieldName)
- field.setAccessible(true)
- field.get(clazz).asInstanceOf[T]
- }
-
- def invokeStatic(clazz: Class[_], methodName: String, args: (Class[_], AnyRef)*): AnyRef = {
- invoke(clazz, null, methodName, args: _*)
- }
-
- def invoke(
- clazz: Class[_],
- obj: AnyRef,
- methodName: String,
- args: (Class[_], AnyRef)*): AnyRef = {
-
- val (types, values) = args.unzip
- val method = clazz.getDeclaredMethod(methodName, types: _*)
- method.setAccessible(true)
- method.invoke(obj, values.toSeq: _*)
- }
-}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
deleted file mode 100755
index 27268ecb923e9..0000000000000
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
+++ /dev/null
@@ -1,344 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.thriftserver
-
-import scala.collection.JavaConversions._
-
-import java.io._
-import java.util.{ArrayList => JArrayList}
-
-import jline.{ConsoleReader, History}
-import org.apache.commons.lang.StringUtils
-import org.apache.commons.logging.LogFactory
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor}
-import org.apache.hadoop.hive.common.LogUtils.LogInitializationException
-import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils}
-import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.ql.Driver
-import org.apache.hadoop.hive.ql.exec.Utilities
-import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory}
-import org.apache.hadoop.hive.ql.session.SessionState
-import org.apache.hadoop.hive.shims.ShimLoader
-import org.apache.thrift.transport.TSocket
-
-import org.apache.spark.sql.Logging
-
-private[hive] object SparkSQLCLIDriver {
- private var prompt = "spark-sql"
- private var continuedPrompt = "".padTo(prompt.length, ' ')
- private var transport:TSocket = _
-
- installSignalHandler()
-
- /**
- * Install an interrupt callback to cancel all Spark jobs. In Hive's CliDriver#processLine(),
- * a signal handler will invoke this registered callback if a Ctrl+C signal is detected while
- * a command is being processed by the current thread.
- */
- def installSignalHandler() {
- HiveInterruptUtils.add(new HiveInterruptCallback {
- override def interrupt() {
- // Handle remote execution mode
- if (SparkSQLEnv.sparkContext != null) {
- SparkSQLEnv.sparkContext.cancelAllJobs()
- } else {
- if (transport != null) {
- // Force closing of TCP connection upon session termination
- transport.getSocket.close()
- }
- }
- }
- })
- }
-
- def main(args: Array[String]) {
- val oproc = new OptionsProcessor()
- if (!oproc.process_stage1(args)) {
- System.exit(1)
- }
-
- // NOTE: It is critical to do this here so that log4j is reinitialized
- // before any of the other core hive classes are loaded
- var logInitFailed = false
- var logInitDetailMessage: String = null
- try {
- logInitDetailMessage = LogUtils.initHiveLog4j()
- } catch {
- case e: LogInitializationException =>
- logInitFailed = true
- logInitDetailMessage = e.getMessage
- }
-
- val sessionState = new CliSessionState(new HiveConf(classOf[SessionState]))
-
- sessionState.in = System.in
- try {
- sessionState.out = new PrintStream(System.out, true, "UTF-8")
- sessionState.info = new PrintStream(System.err, true, "UTF-8")
- sessionState.err = new PrintStream(System.err, true, "UTF-8")
- } catch {
- case e: UnsupportedEncodingException => System.exit(3)
- }
-
- if (!oproc.process_stage2(sessionState)) {
- System.exit(2)
- }
-
- if (!sessionState.getIsSilent) {
- if (logInitFailed) System.err.println(logInitDetailMessage)
- else SessionState.getConsole.printInfo(logInitDetailMessage)
- }
-
- // Set all properties specified via command line.
- val conf: HiveConf = sessionState.getConf
- sessionState.cmdProperties.entrySet().foreach { item: java.util.Map.Entry[Object, Object] =>
- conf.set(item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String])
- sessionState.getOverriddenConfigurations.put(
- item.getKey.asInstanceOf[String], item.getValue.asInstanceOf[String])
- }
-
- SessionState.start(sessionState)
-
- // Clean up after we exit
- Runtime.getRuntime.addShutdownHook(
- new Thread() {
- override def run() {
- SparkSQLEnv.stop()
- }
- }
- )
-
- // "-h" option has been passed, so connect to Hive thrift server.
- if (sessionState.getHost != null) {
- sessionState.connect()
- if (sessionState.isRemoteMode) {
- prompt = s"[${sessionState.getHost}:${sessionState.getPort}]" + prompt
- continuedPrompt = "".padTo(prompt.length, ' ')
- }
- }
-
- if (!sessionState.isRemoteMode && !ShimLoader.getHadoopShims.usesJobShell()) {
- // Hadoop-20 and above - we need to augment classpath using hiveconf
- // components.
- // See also: code in ExecDriver.java
- var loader = conf.getClassLoader
- val auxJars = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEAUXJARS)
- if (StringUtils.isNotBlank(auxJars)) {
- loader = Utilities.addToClassPath(loader, StringUtils.split(auxJars, ","))
- }
- conf.setClassLoader(loader)
- Thread.currentThread().setContextClassLoader(loader)
- }
-
- val cli = new SparkSQLCLIDriver
- cli.setHiveVariables(oproc.getHiveVariables)
-
- // TODO work around for set the log output to console, because the HiveContext
- // will set the output into an invalid buffer.
- sessionState.in = System.in
- try {
- sessionState.out = new PrintStream(System.out, true, "UTF-8")
- sessionState.info = new PrintStream(System.err, true, "UTF-8")
- sessionState.err = new PrintStream(System.err, true, "UTF-8")
- } catch {
- case e: UnsupportedEncodingException => System.exit(3)
- }
-
- // Execute -i init files (always in silent mode)
- cli.processInitFiles(sessionState)
-
- if (sessionState.execString != null) {
- System.exit(cli.processLine(sessionState.execString))
- }
-
- try {
- if (sessionState.fileName != null) {
- System.exit(cli.processFile(sessionState.fileName))
- }
- } catch {
- case e: FileNotFoundException =>
- System.err.println(s"Could not open input file for reading. (${e.getMessage})")
- System.exit(3)
- }
-
- val reader = new ConsoleReader()
- reader.setBellEnabled(false)
- // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true)))
- CliDriver.getCommandCompletor.foreach((e) => reader.addCompletor(e))
-
- val historyDirectory = System.getProperty("user.home")
-
- try {
- if (new File(historyDirectory).exists()) {
- val historyFile = historyDirectory + File.separator + ".hivehistory"
- reader.setHistory(new History(new File(historyFile)))
- } else {
- System.err.println("WARNING: Directory for Hive history file: " + historyDirectory +
- " does not exist. History will not be available during this session.")
- }
- } catch {
- case e: Exception =>
- System.err.println("WARNING: Encountered an error while trying to initialize Hive's " +
- "history file. History will not be available during this session.")
- System.err.println(e.getMessage)
- }
-
- val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport")
- clientTransportTSocketField.setAccessible(true)
-
- transport = clientTransportTSocketField.get(sessionState).asInstanceOf[TSocket]
-
- var ret = 0
- var prefix = ""
- val currentDB = ReflectionUtils.invokeStatic(classOf[CliDriver], "getFormattedDb",
- classOf[HiveConf] -> conf, classOf[CliSessionState] -> sessionState)
-
- def promptWithCurrentDB = s"$prompt$currentDB"
- def continuedPromptWithDBSpaces = continuedPrompt + ReflectionUtils.invokeStatic(
- classOf[CliDriver], "spacesForString", classOf[String] -> currentDB)
-
- var currentPrompt = promptWithCurrentDB
- var line = reader.readLine(currentPrompt + "> ")
-
- while (line != null) {
- if (prefix.nonEmpty) {
- prefix += '\n'
- }
-
- if (line.trim().endsWith(";") && !line.trim().endsWith("\\;")) {
- line = prefix + line
- ret = cli.processLine(line, true)
- prefix = ""
- currentPrompt = promptWithCurrentDB
- } else {
- prefix = prefix + line
- currentPrompt = continuedPromptWithDBSpaces
- }
-
- line = reader.readLine(currentPrompt + "> ")
- }
-
- sessionState.close()
-
- System.exit(ret)
- }
-}
-
-private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
- private val sessionState = SessionState.get().asInstanceOf[CliSessionState]
-
- private val LOG = LogFactory.getLog("CliDriver")
-
- private val console = new SessionState.LogHelper(LOG)
-
- private val conf: Configuration =
- if (sessionState != null) sessionState.getConf else new Configuration()
-
- // Force initializing SparkSQLEnv. This is put here but not object SparkSQLCliDriver
- // because the Hive unit tests do not go through the main() code path.
- if (!sessionState.isRemoteMode) {
- SparkSQLEnv.init()
- }
-
- override def processCmd(cmd: String): Int = {
- val cmd_trimmed: String = cmd.trim()
- val tokens: Array[String] = cmd_trimmed.split("\\s+")
- val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim()
- if (cmd_trimmed.toLowerCase.equals("quit") ||
- cmd_trimmed.toLowerCase.equals("exit") ||
- tokens(0).equalsIgnoreCase("source") ||
- cmd_trimmed.startsWith("!") ||
- tokens(0).toLowerCase.equals("list") ||
- sessionState.isRemoteMode) {
- val start = System.currentTimeMillis()
- super.processCmd(cmd)
- val end = System.currentTimeMillis()
- val timeTaken: Double = (end - start) / 1000.0
- console.printInfo(s"Time taken: $timeTaken seconds")
- 0
- } else {
- var ret = 0
- val hconf = conf.asInstanceOf[HiveConf]
- val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hconf)
-
- if (proc != null) {
- if (proc.isInstanceOf[Driver]) {
- val driver = new SparkSQLDriver
-
- driver.init()
- val out = sessionState.out
- val start:Long = System.currentTimeMillis()
- if (sessionState.getIsVerbose) {
- out.println(cmd)
- }
-
- ret = driver.run(cmd).getResponseCode
- if (ret != 0) {
- driver.close()
- return ret
- }
-
- val res = new JArrayList[String]()
-
- if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CLI_PRINT_HEADER)) {
- // Print the column names.
- Option(driver.getSchema.getFieldSchemas).map { fields =>
- out.println(fields.map(_.getName).mkString("\t"))
- }
- }
-
- try {
- while (!out.checkError() && driver.getResults(res)) {
- res.foreach(out.println)
- res.clear()
- }
- } catch {
- case e:IOException =>
- console.printError(
- s"""Failed with exception ${e.getClass.getName}: ${e.getMessage}
- |${org.apache.hadoop.util.StringUtils.stringifyException(e)}
- """.stripMargin)
- ret = 1
- }
-
- val cret = driver.close()
- if (ret == 0) {
- ret = cret
- }
-
- val end = System.currentTimeMillis()
- if (end > start) {
- val timeTaken:Double = (end - start) / 1000.0
- console.printInfo(s"Time taken: $timeTaken seconds", null)
- }
-
- // Destroy the driver to release all the locks.
- driver.destroy()
- } else {
- if (sessionState.getIsVerbose) {
- sessionState.out.println(tokens(0) + " " + cmd_1)
- }
- ret = proc.run(cmd_1).getResponseCode
- }
- }
- ret
- }
- }
-}
-
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
deleted file mode 100644
index 42cbf363b274f..0000000000000
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.thriftserver
-
-import scala.collection.JavaConversions._
-
-import java.io.IOException
-import java.util.{List => JList}
-import javax.security.auth.login.LoginException
-
-import org.apache.commons.logging.Log
-import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.shims.ShimLoader
-import org.apache.hive.service.Service.STATE
-import org.apache.hive.service.auth.HiveAuthFactory
-import org.apache.hive.service.cli.CLIService
-import org.apache.hive.service.{AbstractService, Service, ServiceException}
-
-import org.apache.spark.sql.hive.HiveContext
-import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
-
-private[hive] class SparkSQLCLIService(hiveContext: HiveContext)
- extends CLIService
- with ReflectedCompositeService {
-
- override def init(hiveConf: HiveConf) {
- setSuperField(this, "hiveConf", hiveConf)
-
- val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext)
- setSuperField(this, "sessionManager", sparkSqlSessionManager)
- addService(sparkSqlSessionManager)
-
- try {
- HiveAuthFactory.loginFromKeytab(hiveConf)
- val serverUserName = ShimLoader.getHadoopShims
- .getShortUserName(ShimLoader.getHadoopShims.getUGIForConf(hiveConf))
- setSuperField(this, "serverUserName", serverUserName)
- } catch {
- case e @ (_: IOException | _: LoginException) =>
- throw new ServiceException("Unable to login to kerberos with given principal/keytab", e)
- }
-
- initCompositeService(hiveConf)
- }
-}
-
-private[thriftserver] trait ReflectedCompositeService { this: AbstractService =>
- def initCompositeService(hiveConf: HiveConf) {
- // Emulating `CompositeService.init(hiveConf)`
- val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList")
- serviceList.foreach(_.init(hiveConf))
-
- // Emulating `AbstractService.init(hiveConf)`
- invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.NOTINITED)
- setAncestorField(this, 3, "hiveConf", hiveConf)
- invoke(classOf[AbstractService], this, "changeState", classOf[STATE] -> STATE.INITED)
- getAncestorField[Log](this, 3, "LOG").info(s"Service: $getName is inited.")
- }
-}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
deleted file mode 100644
index 5202aa9903e03..0000000000000
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.thriftserver
-
-import scala.collection.JavaConversions._
-
-import java.util.{ArrayList => JArrayList}
-
-import org.apache.commons.lang.exception.ExceptionUtils
-import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema}
-import org.apache.hadoop.hive.ql.Driver
-import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse
-
-import org.apache.spark.sql.Logging
-import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
-
-private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext)
- extends Driver with Logging {
-
- private var tableSchema: Schema = _
- private var hiveResponse: Seq[String] = _
-
- override def init(): Unit = {
- }
-
- private def getResultSetSchema(query: context.QueryExecution): Schema = {
- val analyzed = query.analyzed
- logger.debug(s"Result Schema: ${analyzed.output}")
- if (analyzed.output.size == 0) {
- new Schema(new FieldSchema("Response code", "string", "") :: Nil, null)
- } else {
- val fieldSchemas = analyzed.output.map { attr =>
- new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "")
- }
-
- new Schema(fieldSchemas, null)
- }
- }
-
- override def run(command: String): CommandProcessorResponse = {
- val execution = context.executePlan(context.hql(command).logicalPlan)
-
- // TODO unify the error code
- try {
- hiveResponse = execution.stringResult()
- tableSchema = getResultSetSchema(execution)
- new CommandProcessorResponse(0)
- } catch {
- case cause: Throwable =>
- logger.error(s"Failed in [$command]", cause)
- new CommandProcessorResponse(-3, ExceptionUtils.getFullStackTrace(cause), null)
- }
- }
-
- override def close(): Int = {
- hiveResponse = null
- tableSchema = null
- 0
- }
-
- override def getSchema: Schema = tableSchema
-
- override def getResults(res: JArrayList[String]): Boolean = {
- if (hiveResponse == null) {
- false
- } else {
- res.addAll(hiveResponse)
- hiveResponse = null
- true
- }
- }
-
- override def destroy() {
- super.destroy()
- hiveResponse = null
- tableSchema = null
- }
-}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
deleted file mode 100644
index 451c3bd7b9352..0000000000000
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.thriftserver
-
-import org.apache.hadoop.hive.ql.session.SessionState
-
-import org.apache.spark.scheduler.{SplitInfo, StatsReportListener}
-import org.apache.spark.sql.Logging
-import org.apache.spark.sql.hive.HiveContext
-import org.apache.spark.{SparkConf, SparkContext}
-
-/** A singleton object for the master program. The slaves should not access this. */
-private[hive] object SparkSQLEnv extends Logging {
- logger.debug("Initializing SparkSQLEnv")
-
- var hiveContext: HiveContext = _
- var sparkContext: SparkContext = _
-
- def init() {
- if (hiveContext == null) {
- sparkContext = new SparkContext(new SparkConf()
- .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}"))
-
- sparkContext.addSparkListener(new StatsReportListener())
-
- hiveContext = new HiveContext(sparkContext) {
- @transient override lazy val sessionState = SessionState.get()
- @transient override lazy val hiveconf = sessionState.getConf
- }
- }
- }
-
- /** Cleans up and shuts down the Spark SQL environments. */
- def stop() {
- logger.debug("Shutting down Spark SQL Environment")
- // Stop the SparkContext
- if (SparkSQLEnv.sparkContext != null) {
- sparkContext.stop()
- sparkContext = null
- hiveContext = null
- }
- }
-}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
deleted file mode 100644
index 6b3275b4eaf04..0000000000000
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.thriftserver
-
-import java.util.concurrent.Executors
-
-import org.apache.commons.logging.Log
-import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.apache.hive.service.cli.session.SessionManager
-
-import org.apache.spark.sql.hive.HiveContext
-import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
-import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager
-
-private[hive] class SparkSQLSessionManager(hiveContext: HiveContext)
- extends SessionManager
- with ReflectedCompositeService {
-
- override def init(hiveConf: HiveConf) {
- setSuperField(this, "hiveConf", hiveConf)
-
- val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS)
- setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize))
- getAncestorField[Log](this, 3, "LOG").info(
- s"HiveServer2: Async execution pool size $backgroundPoolSize")
-
- val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext)
- setSuperField(this, "operationManager", sparkSqlOperationManager)
- addService(sparkSqlOperationManager)
-
- initCompositeService(hiveConf)
- }
-}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
deleted file mode 100644
index a4e1f3e762e89..0000000000000
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ /dev/null
@@ -1,151 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.thriftserver.server
-
-import scala.collection.JavaConversions._
-import scala.collection.mutable.ArrayBuffer
-import scala.math.{random, round}
-
-import java.sql.Timestamp
-import java.util.{Map => JMap}
-
-import org.apache.hadoop.hive.common.`type`.HiveDecimal
-import org.apache.hadoop.hive.metastore.api.FieldSchema
-import org.apache.hive.service.cli._
-import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager}
-import org.apache.hive.service.cli.session.HiveSession
-
-import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.hive.thriftserver.ReflectionUtils
-import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
-import org.apache.spark.sql.{Logging, SchemaRDD, Row => SparkRow}
-
-/**
- * Executes queries using Spark SQL, and maintains a list of handles to active queries.
- */
-class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManager with Logging {
- val handleToOperation = ReflectionUtils
- .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation")
-
- override def newExecuteStatementOperation(
- parentSession: HiveSession,
- statement: String,
- confOverlay: JMap[String, String],
- async: Boolean): ExecuteStatementOperation = synchronized {
-
- val operation = new ExecuteStatementOperation(parentSession, statement, confOverlay) {
- private var result: SchemaRDD = _
- private var iter: Iterator[SparkRow] = _
- private var dataTypes: Array[DataType] = _
-
- def close(): Unit = {
- // RDDs will be cleaned automatically upon garbage collection.
- logger.debug("CLOSING")
- }
-
- def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = {
- if (!iter.hasNext) {
- new RowSet()
- } else {
- val maxRows = maxRowsL.toInt // Do you really want a row batch larger than Int Max? No.
- var curRow = 0
- var rowSet = new ArrayBuffer[Row](maxRows)
-
- while (curRow < maxRows && iter.hasNext) {
- val sparkRow = iter.next()
- val row = new Row()
- var curCol = 0
-
- while (curCol < sparkRow.length) {
- dataTypes(curCol) match {
- case StringType =>
- row.addString(sparkRow(curCol).asInstanceOf[String])
- case IntegerType =>
- row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol)))
- case BooleanType =>
- row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol)))
- case DoubleType =>
- row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol)))
- case FloatType =>
- row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol)))
- case DecimalType =>
- val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal
- row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal)))
- case LongType =>
- row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol)))
- case ByteType =>
- row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol)))
- case ShortType =>
- row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol)))
- case TimestampType =>
- row.addColumnValue(
- ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp]))
- case BinaryType | _: ArrayType | _: StructType | _: MapType =>
- val hiveString = result
- .queryExecution
- .asInstanceOf[HiveContext#QueryExecution]
- .toHiveString((sparkRow.get(curCol), dataTypes(curCol)))
- row.addColumnValue(ColumnValue.stringValue(hiveString))
- }
- curCol += 1
- }
- rowSet += row
- curRow += 1
- }
- new RowSet(rowSet, 0)
- }
- }
-
- def getResultSetSchema: TableSchema = {
- logger.warn(s"Result Schema: ${result.queryExecution.analyzed.output}")
- if (result.queryExecution.analyzed.output.size == 0) {
- new TableSchema(new FieldSchema("Result", "string", "") :: Nil)
- } else {
- val schema = result.queryExecution.analyzed.output.map { attr =>
- new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "")
- }
- new TableSchema(schema)
- }
- }
-
- def run(): Unit = {
- logger.info(s"Running query '$statement'")
- setState(OperationState.RUNNING)
- try {
- result = hiveContext.hql(statement)
- logger.debug(result.queryExecution.toString())
- val groupId = round(random * 1000000).toString
- hiveContext.sparkContext.setJobGroup(groupId, statement)
- iter = result.queryExecution.toRdd.toLocalIterator
- dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray
- setHasResultSet(true)
- } catch {
- // Actually do need to catch Throwable as some failures don't inherit from Exception and
- // HiveServer will silently swallow them.
- case e: Throwable =>
- logger.error("Error executing query:",e)
- throw new HiveSQLException(e.toString)
- }
- setState(OperationState.FINISHED)
- }
- }
-
- handleToOperation.put(operation.getHandle, operation)
- operation
- }
-}
diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt
deleted file mode 100644
index 850f8014b6f05..0000000000000
--- a/sql/hive-thriftserver/src/test/resources/data/files/small_kv.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-238val_238
-86val_86
-311val_311
-27val_27
-165val_165
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
deleted file mode 100644
index b90670a796b81..0000000000000
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.thriftserver
-
-import java.io.{BufferedReader, InputStreamReader, PrintWriter}
-
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
-
-import org.apache.spark.sql.hive.test.TestHive
-
-class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils {
- val WAREHOUSE_PATH = TestUtils.getWarehousePath("cli")
- val METASTORE_PATH = TestUtils.getMetastorePath("cli")
-
- override def beforeAll() {
- val pb = new ProcessBuilder(
- "../../bin/spark-sql",
- "--master",
- "local",
- "--hiveconf",
- s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true",
- "--hiveconf",
- "hive.metastore.warehouse.dir=" + WAREHOUSE_PATH)
-
- process = pb.start()
- outputWriter = new PrintWriter(process.getOutputStream, true)
- inputReader = new BufferedReader(new InputStreamReader(process.getInputStream))
- errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream))
- waitForOutput(inputReader, "spark-sql>")
- }
-
- override def afterAll() {
- process.destroy()
- process.waitFor()
- }
-
- test("simple commands") {
- val dataFilePath = getDataFile("data/files/small_kv.txt")
- executeQuery("create table hive_test1(key int, val string);")
- executeQuery("load data local inpath '" + dataFilePath+ "' overwrite into table hive_test1;")
- executeQuery("cache table hive_test1", "Time taken")
- }
-}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
deleted file mode 100644
index 59f4952b78bc6..0000000000000
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
+++ /dev/null
@@ -1,125 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.thriftserver
-
-import scala.collection.JavaConversions._
-import scala.concurrent.ExecutionContext.Implicits.global
-import scala.concurrent._
-
-import java.io.{BufferedReader, InputStreamReader}
-import java.sql.{Connection, DriverManager, Statement}
-
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
-
-import org.apache.spark.sql.Logging
-import org.apache.spark.sql.catalyst.util.getTempFilePath
-
-/**
- * Test for the HiveThriftServer2 using JDBC.
- */
-class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUtils with Logging {
-
- val WAREHOUSE_PATH = getTempFilePath("warehouse")
- val METASTORE_PATH = getTempFilePath("metastore")
-
- val DRIVER_NAME = "org.apache.hive.jdbc.HiveDriver"
- val TABLE = "test"
- // use a different port, than the hive standard 10000,
- // for tests to avoid issues with the port being taken on some machines
- val PORT = "10000"
-
- // If verbose is true, the test program will print all outputs coming from the Hive Thrift server.
- val VERBOSE = Option(System.getenv("SPARK_SQL_TEST_VERBOSE")).getOrElse("false").toBoolean
-
- Class.forName(DRIVER_NAME)
-
- override def beforeAll() { launchServer() }
-
- override def afterAll() { stopServer() }
-
- private def launchServer(args: Seq[String] = Seq.empty) {
- // Forking a new process to start the Hive Thrift server. The reason to do this is it is
- // hard to clean up Hive resources entirely, so we just start a new process and kill
- // that process for cleanup.
- val defaultArgs = Seq(
- "../../sbin/start-thriftserver.sh",
- "--master local",
- "--hiveconf",
- "hive.root.logger=INFO,console",
- "--hiveconf",
- s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true",
- "--hiveconf",
- s"hive.metastore.warehouse.dir=$WAREHOUSE_PATH")
- val pb = new ProcessBuilder(defaultArgs ++ args)
- process = pb.start()
- inputReader = new BufferedReader(new InputStreamReader(process.getInputStream))
- errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream))
- waitForOutput(inputReader, "ThriftBinaryCLIService listening on")
-
- // Spawn a thread to read the output from the forked process.
- // Note that this is necessary since in some configurations, log4j could be blocked
- // if its output to stderr are not read, and eventually blocking the entire test suite.
- future {
- while (true) {
- val stdout = readFrom(inputReader)
- val stderr = readFrom(errorReader)
- if (VERBOSE && stdout.length > 0) {
- println(stdout)
- }
- if (VERBOSE && stderr.length > 0) {
- println(stderr)
- }
- Thread.sleep(50)
- }
- }
- }
-
- private def stopServer() {
- process.destroy()
- process.waitFor()
- }
-
- test("test query execution against a Hive Thrift server") {
- Thread.sleep(5 * 1000)
- val dataFilePath = getDataFile("data/files/small_kv.txt")
- val stmt = createStatement()
- stmt.execute("DROP TABLE IF EXISTS test")
- stmt.execute("DROP TABLE IF EXISTS test_cached")
- stmt.execute("CREATE TABLE test(key int, val string)")
- stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test")
- stmt.execute("CREATE TABLE test_cached as select * from test limit 4")
- stmt.execute("CACHE TABLE test_cached")
-
- var rs = stmt.executeQuery("select count(*) from test")
- rs.next()
- assert(rs.getInt(1) === 5)
-
- rs = stmt.executeQuery("select count(*) from test_cached")
- rs.next()
- assert(rs.getInt(1) === 4)
-
- stmt.close()
- }
-
- def getConnection: Connection = {
- val connectURI = s"jdbc:hive2://localhost:$PORT/"
- DriverManager.getConnection(connectURI, System.getProperty("user.name"), "")
- }
-
- def createStatement(): Statement = getConnection.createStatement()
-}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala
deleted file mode 100644
index bb2242618fbef..0000000000000
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala
+++ /dev/null
@@ -1,108 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.thriftserver
-
-import java.io.{BufferedReader, PrintWriter}
-import java.text.SimpleDateFormat
-import java.util.Date
-
-import org.apache.hadoop.hive.common.LogUtils
-import org.apache.hadoop.hive.common.LogUtils.LogInitializationException
-
-object TestUtils {
- val timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss")
-
- def getWarehousePath(prefix: String): String = {
- System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-warehouse-" +
- timestamp.format(new Date)
- }
-
- def getMetastorePath(prefix: String): String = {
- System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-metastore-" +
- timestamp.format(new Date)
- }
-
- // Dummy function for initialize the log4j properties.
- def init() { }
-
- // initialize log4j
- try {
- LogUtils.initHiveLog4j()
- } catch {
- case e: LogInitializationException => // Ignore the error.
- }
-}
-
-trait TestUtils {
- var process : Process = null
- var outputWriter : PrintWriter = null
- var inputReader : BufferedReader = null
- var errorReader : BufferedReader = null
-
- def executeQuery(
- cmd: String, outputMessage: String = "OK", timeout: Long = 15000): String = {
- println("Executing: " + cmd + ", expecting output: " + outputMessage)
- outputWriter.write(cmd + "\n")
- outputWriter.flush()
- waitForQuery(timeout, outputMessage)
- }
-
- protected def waitForQuery(timeout: Long, message: String): String = {
- if (waitForOutput(errorReader, message, timeout)) {
- Thread.sleep(500)
- readOutput()
- } else {
- assert(false, "Didn't find \"" + message + "\" in the output:\n" + readOutput())
- null
- }
- }
-
- // Wait for the specified str to appear in the output.
- protected def waitForOutput(
- reader: BufferedReader, str: String, timeout: Long = 10000): Boolean = {
- val startTime = System.currentTimeMillis
- var out = ""
- while (!out.contains(str) && System.currentTimeMillis < (startTime + timeout)) {
- out += readFrom(reader)
- }
- out.contains(str)
- }
-
- // Read stdout output and filter out garbage collection messages.
- protected def readOutput(): String = {
- val output = readFrom(inputReader)
- // Remove GC Messages
- val filteredOutput = output.lines.filterNot(x => x.contains("[GC") || x.contains("[Full GC"))
- .mkString("\n")
- filteredOutput
- }
-
- protected def readFrom(reader: BufferedReader): String = {
- var out = ""
- var c = 0
- while (reader.ready) {
- c = reader.read()
- out += c.asInstanceOf[Char]
- }
- out
- }
-
- protected def getDataFile(name: String) = {
- Thread.currentThread().getContextClassLoader.getResource(name)
- }
-}
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 93d00f7c37c9b..1699ffe06ce15 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -32,7 +32,7 @@
Spark Project Hive
http://spark.apache.org/
- hive
+ hive
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 84d43eaeea51d..201c85f3d501e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -255,7 +255,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType,
ShortType, DecimalType, TimestampType, BinaryType)
- protected[sql] def toHiveString(a: (Any, DataType)): String = a match {
+ protected def toHiveString(a: (Any, DataType)): String = a match {
case (struct: Row, StructType(fields)) =>
struct.zip(fields).map {
case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 4395874526d51..e6ab68b563f8d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -925,7 +925,8 @@ private[hive] object HiveQl {
case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right))
case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right))
case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
- case Token(DIV(), left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
+ case Token(DIV(), left :: right:: Nil) =>
+ Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType)
case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right))
/* Comparisons */
diff --git a/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585 b/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585
new file mode 100644
index 0000000000000..17ba0bea723c6
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585
@@ -0,0 +1 @@
+0 0 0 1 2
diff --git a/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 b/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7
new file mode 100644
index 0000000000000..7b7a9175114ce
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7
@@ -0,0 +1 @@
+2.0 0.5 0.3333333333333333 0.002
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 08ef4d9b6bb93..b4dbf2b115799 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -350,12 +350,6 @@ abstract class HiveComparisonTest
val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n")
- println("hive output")
- hive.foreach(println)
-
- println("catalyst printout")
- catalyst.foreach(println)
-
if (recomputeCache) {
logger.warn(s"Clearing cache files for failed test $testCaseName")
hiveCacheFiles.foreach(_.delete())
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 8489f2a34e63c..a8623b64c656f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -52,7 +52,10 @@ class HiveQuerySuite extends HiveComparisonTest {
"SELECT * FROM src WHERE key Between 1 and 2")
createQueryTest("div",
- "SELECT 1 DIV 2, 1 div 2, 1 dIv 2 FROM src LIMIT 1")
+ "SELECT 1 DIV 2, 1 div 2, 1 dIv 2, 100 DIV 51, 100 DIV 49 FROM src LIMIT 1")
+
+ createQueryTest("division",
+ "SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1")
test("Query expressed in SQL") {
assert(sql("SELECT 1").collect() === Array(Seq(1)))
@@ -416,10 +419,10 @@ class HiveQuerySuite extends HiveComparisonTest {
hql(s"set $testKey=$testVal")
assert(get(testKey, testVal + "_") == testVal)
- hql("set some.property=20")
- assert(get("some.property", "0") == "20")
- hql("set some.property = 40")
- assert(get("some.property", "0") == "40")
+ hql("set mapred.reduce.tasks=20")
+ assert(get("mapred.reduce.tasks", "0") == "20")
+ hql("set mapred.reduce.tasks = 40")
+ assert(get("mapred.reduce.tasks", "0") == "40")
hql(s"set $testKey=$testVal")
assert(get(testKey, "0") == testVal)
@@ -433,61 +436,63 @@ class HiveQuerySuite extends HiveComparisonTest {
val testKey = "spark.sql.key.usedfortestonly"
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
+ def collectResults(rdd: SchemaRDD): Set[(String, String)] =
+ rdd.collect().map { case Row(key: String, value: String) => key -> value }.toSet
clear()
// "set" itself returns all config variables currently specified in SQLConf.
assert(hql("SET").collect().size == 0)
- assertResult(Array(s"$testKey=$testVal")) {
- hql(s"SET $testKey=$testVal").collect().map(_.getString(0))
+ assertResult(Set(testKey -> testVal)) {
+ collectResults(hql(s"SET $testKey=$testVal"))
}
assert(hiveconf.get(testKey, "") == testVal)
- assertResult(Array(s"$testKey=$testVal")) {
- hql(s"SET $testKey=$testVal").collect().map(_.getString(0))
+ assertResult(Set(testKey -> testVal)) {
+ collectResults(hql("SET"))
}
hql(s"SET ${testKey + testKey}=${testVal + testVal}")
assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
- assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) {
- hql(s"SET").collect().map(_.getString(0))
+ assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
+ collectResults(hql("SET"))
}
// "set key"
- assertResult(Array(s"$testKey=$testVal")) {
- hql(s"SET $testKey").collect().map(_.getString(0))
+ assertResult(Set(testKey -> testVal)) {
+ collectResults(hql(s"SET $testKey"))
}
- assertResult(Array(s"$nonexistentKey=")) {
- hql(s"SET $nonexistentKey").collect().map(_.getString(0))
+ assertResult(Set(nonexistentKey -> "")) {
+ collectResults(hql(s"SET $nonexistentKey"))
}
// Assert that sql() should have the same effects as hql() by repeating the above using sql().
clear()
assert(sql("SET").collect().size == 0)
- assertResult(Array(s"$testKey=$testVal")) {
- sql(s"SET $testKey=$testVal").collect().map(_.getString(0))
+ assertResult(Set(testKey -> testVal)) {
+ collectResults(sql(s"SET $testKey=$testVal"))
}
assert(hiveconf.get(testKey, "") == testVal)
- assertResult(Array(s"$testKey=$testVal")) {
- sql("SET").collect().map(_.getString(0))
+ assertResult(Set(testKey -> testVal)) {
+ collectResults(sql("SET"))
}
sql(s"SET ${testKey + testKey}=${testVal + testVal}")
assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
- assertResult(Array(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) {
- sql("SET").collect().map(_.getString(0))
+ assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
+ collectResults(sql("SET"))
}
- assertResult(Array(s"$testKey=$testVal")) {
- sql(s"SET $testKey").collect().map(_.getString(0))
+ assertResult(Set(testKey -> testVal)) {
+ collectResults(sql(s"SET $testKey"))
}
- assertResult(Array(s"$nonexistentKey=")) {
- sql(s"SET $nonexistentKey").collect().map(_.getString(0))
+ assertResult(Set(nonexistentKey -> "")) {
+ collectResults(sql(s"SET $nonexistentKey"))
}
clear()
diff --git a/streaming/pom.xml b/streaming/pom.xml
index b99f306b8f2cc..f60697ce745b7 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -28,7 +28,7 @@
org.apache.spark
spark-streaming_2.10
- streaming
+ streaming
jar
Spark Project Streaming
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index ce8316bb14891..d934b9cbfc3e8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -110,8 +110,7 @@ private[streaming] class ReceiverSupervisorImpl(
) {
val blockId = optionalBlockId.getOrElse(nextBlockId)
val time = System.currentTimeMillis
- blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]],
- storageLevel, tellMaster = true)
+ blockManager.putArray(blockId, arrayBuffer.toArray[Any], storageLevel, tellMaster = true)
logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms")
reportPushedBlock(blockId, arrayBuffer.size, optionalMetadata)
}
@@ -124,7 +123,7 @@ private[streaming] class ReceiverSupervisorImpl(
) {
val blockId = optionalBlockId.getOrElse(nextBlockId)
val time = System.currentTimeMillis
- blockManager.put(blockId, iterator, storageLevel, tellMaster = true)
+ blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true)
logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms")
reportPushedBlock(blockId, -1, optionalMetadata)
}
diff --git a/tools/pom.xml b/tools/pom.xml
index 97abb6b2b63e0..c0ee8faa7a615 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -27,7 +27,7 @@
org.apache.spark
spark-tools_2.10
- tools
+ tools
jar
Spark Project Tools
diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml
index 51744ece0412d..5b13a1f002d6e 100644
--- a/yarn/alpha/pom.xml
+++ b/yarn/alpha/pom.xml
@@ -24,7 +24,7 @@
../pom.xml
- yarn-alpha
+ yarn-alpha
org.apache.spark
diff --git a/yarn/pom.xml b/yarn/pom.xml
index 3faaf053634d6..efb473aa1b261 100644
--- a/yarn/pom.xml
+++ b/yarn/pom.xml
@@ -29,7 +29,7 @@
pom
Spark Project YARN Parent POM
- yarn
+ yarn
diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml
index b6c8456d06684..ceaf9f9d71001 100644
--- a/yarn/stable/pom.xml
+++ b/yarn/stable/pom.xml
@@ -24,7 +24,7 @@
../pom.xml
- yarn-stable
+ yarn-stable
org.apache.spark