From 8fb3a65cbb714120d612e58ef9d12b0521a83260 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 14 Jul 2015 12:47:11 -0700 Subject: [PATCH 01/46] [SPARK-8911] Fix local mode endless heartbeats As of #7173 we expect executors to properly register with the driver before responding to their heartbeats. This behavior is not matched in local mode. This patch adds the missing event that needs to be posted. Author: Andrew Or Closes #7382 from andrewor14/fix-local-heartbeat and squashes the following commits: 1258bdf [Andrew Or] Post ExecutorAdded event to local executor --- .../spark/scheduler/local/LocalBackend.scala | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 776e5d330e3c7..4d48fcfea44e7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -25,7 +25,8 @@ import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} -import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo private case class ReviveOffers() @@ -50,8 +51,8 @@ private[spark] class LocalEndpoint( private var freeCores = totalCores - private val localExecutorId = SparkContext.DRIVER_IDENTIFIER - private val localExecutorHostname = "localhost" + val localExecutorId = SparkContext.DRIVER_IDENTIFIER + val localExecutorHostname = "localhost" private val executor = new Executor( localExecutorId, localExecutorHostname, SparkEnv.get, userClassPath, isLocal = true) @@ -99,8 +100,9 @@ private[spark] class LocalBackend( extends SchedulerBackend with ExecutorBackend with Logging { private val appId = "local-" + System.currentTimeMillis - var localEndpoint: RpcEndpointRef = null + private var localEndpoint: RpcEndpointRef = null private val userClassPath = getUserClasspath(conf) + private val listenerBus = scheduler.sc.listenerBus /** * Returns a list of URLs representing the user classpath. @@ -113,9 +115,13 @@ private[spark] class LocalBackend( } override def start() { - localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint( - "LocalBackendEndpoint", - new LocalEndpoint(SparkEnv.get.rpcEnv, userClassPath, scheduler, this, totalCores)) + val rpcEnv = SparkEnv.get.rpcEnv + val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores) + localEndpoint = rpcEnv.setupEndpoint("LocalBackendEndpoint", executorEndpoint) + listenerBus.post(SparkListenerExecutorAdded( + System.currentTimeMillis, + executorEndpoint.localExecutorId, + new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty))) } override def stop() { From d267c2834a639aaebd0559355c6a82613abb689b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 14 Jul 2015 12:56:17 -0700 Subject: [PATCH 02/46] [SPARK-9031] Merge BlockObjectWriter and DiskBlockObject writer to remove abstract class BlockObjectWriter has only one concrete non-test class, DiskBlockObjectWriter. 
In order to simplify the code in preparation for other refactorings, I think that we should remove this base class and have only DiskBlockObjectWriter. While at one time we may have planned to have multiple BlockObjectWriter implementations, that doesn't seem to have happened, so the extra abstraction seems unnecessary. Author: Josh Rosen Closes #7391 from JoshRosen/shuffle-write-interface-refactoring and squashes the following commits: c418e33 [Josh Rosen] Fix compilation 5047995 [Josh Rosen] Fix comments d5dc548 [Josh Rosen] Update references in comments 89dc797 [Josh Rosen] Rename test suite. 5755918 [Josh Rosen] Remove unnecessary val in case class 1607c91 [Josh Rosen] Merge BlockObjectWriter and DiskBlockObjectWriter --- .../sort/BypassMergeSortShuffleWriter.java | 8 +- .../unsafe/UnsafeShuffleExternalSorter.java | 2 +- .../unsafe/sort/UnsafeSorterSpillWriter.java | 4 +- .../shuffle/FileShuffleBlockResolver.scala | 8 +- .../shuffle/IndexShuffleBlockResolver.scala | 2 +- .../shuffle/hash/HashShuffleWriter.scala | 4 +- .../apache/spark/storage/BlockManager.scala | 2 +- ...iter.scala => DiskBlockObjectWriter.scala} | 96 +++++++------------ .../spark/util/collection/ChainedBuffer.scala | 2 +- .../util/collection/ExternalSorter.scala | 4 +- .../collection/PartitionedPairBuffer.scala | 1 - .../PartitionedSerializedPairBuffer.scala | 5 +- .../WritablePartitionedPairCollection.scala | 8 +- .../BypassMergeSortShuffleWriterSuite.scala | 4 +- ...scala => DiskBlockObjectWriterSuite.scala} | 2 +- ...PartitionedSerializedPairBufferSuite.scala | 52 +++++----- 16 files changed, 90 insertions(+), 114 deletions(-) rename core/src/main/scala/org/apache/spark/storage/{BlockObjectWriter.scala => DiskBlockObjectWriter.scala} (83%) rename core/src/test/scala/org/apache/spark/storage/{BlockObjectWriterSuite.scala => DiskBlockObjectWriterSuite.scala} (98%) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index d3d6280284beb..0b8b604e18494 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -75,7 +75,7 @@ final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter< private final Serializer serializer; /** Array of file writers, one for each partition */ - private BlockObjectWriter[] partitionWriters; + private DiskBlockObjectWriter[] partitionWriters; public BypassMergeSortShuffleWriter( SparkConf conf, @@ -101,7 +101,7 @@ public void insertAll(Iterator> records) throws IOException { } final SerializerInstance serInstance = serializer.newInstance(); final long openStartTime = System.nanoTime(); - partitionWriters = new BlockObjectWriter[numPartitions]; + partitionWriters = new DiskBlockObjectWriter[numPartitions]; for (int i = 0; i < numPartitions; i++) { final Tuple2 tempShuffleBlockIdPlusFile = blockManager.diskBlockManager().createTempShuffleBlock(); @@ -121,7 +121,7 @@ public void insertAll(Iterator> records) throws IOException { partitionWriters[partitioner.getPartition(key)].write(key, record._2()); } - for (BlockObjectWriter writer : partitionWriters) { + for (DiskBlockObjectWriter writer : partitionWriters) { writer.commitAndClose(); } } @@ -169,7 +169,7 @@ public void stop() throws IOException { if (partitionWriters != null) { try { final DiskBlockManager diskBlockManager = blockManager.diskBlockManager(); - for 
(BlockObjectWriter writer : partitionWriters) { + for (DiskBlockObjectWriter writer : partitionWriters) { // This method explicitly does _not_ throw exceptions: writer.revertPartialWritesAndClose(); if (!diskBlockManager.getFile(writer.blockId()).delete()) { diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 56289573209fb..1d460432be9ff 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -157,7 +157,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this // after SPARK-5581 is fixed. - BlockObjectWriter writer; + DiskBlockObjectWriter writer; // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to // be an API to directly transfer bytes from managed memory to the disk writer, we buffer diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index b8d66659804ad..71eed29563d4a 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -26,7 +26,7 @@ import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; -import org.apache.spark.storage.BlockObjectWriter; +import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempLocalBlockId; import org.apache.spark.unsafe.PlatformDependent; @@ -47,7 +47,7 @@ final class UnsafeSorterSpillWriter { private final File file; private final BlockId blockId; private final int numRecordsToWrite; - private BlockObjectWriter writer; + private DiskBlockObjectWriter writer; private int numRecordsSpilled = 0; public UnsafeSorterSpillWriter( diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 6c3b3080d2605..f6a96d81e7aa9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVecto /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { - val writers: Array[BlockObjectWriter] + val writers: Array[DiskBlockObjectWriter] /** @param success Indicates all writes were successful. If false, no blocks will be recorded. 
*/ def releaseWriters(success: Boolean) @@ -113,15 +113,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() - val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) { + val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) { fileGroup = getUnusedFileGroup() - Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize, writeMetrics) } } else { - Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) val blockFile = blockManager.diskBlockManager.getFile(blockId) // Because of previous failures, the shuffle file may already exist on this machine. diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d9c63b6e7bbb9..fae69551e7330 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -114,7 +114,7 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB } private[spark] object IndexShuffleBlockResolver { - // No-op reduce ID used in interactions with disk store and BlockObjectWriter. + // No-op reduce ID used in interactions with disk store and DiskBlockObjectWriter. // The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort // shuffle outputs for several reduces are glommed into a single file. // TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId. diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index eb87cee15903c..41df70c602c30 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -22,7 +22,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter private[spark] class HashShuffleWriter[K, V]( shuffleBlockResolver: FileShuffleBlockResolver, @@ -102,7 +102,7 @@ private[spark] class HashShuffleWriter[K, V]( private def commitWritesAndBuildStatus(): MapStatus = { // Commit the writes. Get the size of each bucket block (total block size). 
- val sizes: Array[Long] = shuffle.writers.map { writer: BlockObjectWriter => + val sizes: Array[Long] = shuffle.writers.map { writer: DiskBlockObjectWriter => writer.commitAndClose() writer.fileSegment().length } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1beafa1771448..86493673d958d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -648,7 +648,7 @@ private[spark] class BlockManager( file: File, serializerInstance: SerializerInstance, bufferSize: Int, - writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = { + writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala similarity index 83% rename from core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala rename to core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index 7eeabd1e0489c..49d9154f95a5b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -26,66 +26,25 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.util.Utils /** - * An interface for writing JVM objects to some underlying storage. This interface allows - * appending data to an existing block, and can guarantee atomicity in the case of faults - * as it allows the caller to revert partial writes. + * A class for writing JVM objects directly to a file on disk. This class allows data to be appended + * to an existing block and can guarantee atomicity in the case of faults as it allows the caller to + * revert partial writes. * - * This interface does not support concurrent writes. Also, once the writer has - * been opened, it cannot be reopened again. - */ -private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream { - - def open(): BlockObjectWriter - - def close() - - def isOpen: Boolean - - /** - * Flush the partial writes and commit them as a single atomic block. - */ - def commitAndClose(): Unit - - /** - * Reverts writes that haven't been flushed yet. Callers should invoke this function - * when there are runtime exceptions. This method will not throw, though it may be - * unsuccessful in truncating written data. - */ - def revertPartialWritesAndClose() - - /** - * Writes a key-value pair. - */ - def write(key: Any, value: Any) - - /** - * Notify the writer that a record worth of bytes has been written with OutputStream#write. - */ - def recordWritten() - - /** - * Returns the file segment of committed data that this Writer has written. - * This is only valid after commitAndClose() has been called. - */ - def fileSegment(): FileSegment -} - -/** - * BlockObjectWriter which writes directly to a file on disk. Appends to the given file. + * This class does not support concurrent writes. Also, once the writer has been opened it cannot be + * reopened again. 
*/ private[spark] class DiskBlockObjectWriter( - blockId: BlockId, + val blockId: BlockId, file: File, serializerInstance: SerializerInstance, bufferSize: Int, compressStream: OutputStream => OutputStream, syncWrites: Boolean, - // These write metrics concurrently shared with other active BlockObjectWriter's who + // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. writeMetrics: ShuffleWriteMetrics) - extends BlockObjectWriter(blockId) - with Logging -{ + extends OutputStream + with Logging { /** The file channel, used for repositioning / truncating the file. */ private var channel: FileChannel = null @@ -122,7 +81,7 @@ private[spark] class DiskBlockObjectWriter( */ private var numRecordsWritten = 0 - override def open(): BlockObjectWriter = { + def open(): DiskBlockObjectWriter = { if (hasBeenClosed) { throw new IllegalStateException("Writer already closed. Cannot be reopened.") } @@ -159,9 +118,12 @@ private[spark] class DiskBlockObjectWriter( } } - override def isOpen: Boolean = objOut != null + def isOpen: Boolean = objOut != null - override def commitAndClose(): Unit = { + /** + * Flush the partial writes and commit them as a single atomic block. + */ + def commitAndClose(): Unit = { if (initialized) { // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the // serializer stream and the lower level stream. @@ -177,9 +139,15 @@ private[spark] class DiskBlockObjectWriter( commitAndCloseHasBeenCalled = true } - // Discard current writes. We do this by flushing the outstanding writes and then - // truncating the file to its initial position. - override def revertPartialWritesAndClose() { + + /** + * Reverts writes that haven't been flushed yet. Callers should invoke this function + * when there are runtime exceptions. This method will not throw, though it may be + * unsuccessful in truncating written data. + */ + def revertPartialWritesAndClose() { + // Discard current writes. We do this by flushing the outstanding writes and then + // truncating the file to its initial position. try { if (initialized) { writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) @@ -201,7 +169,10 @@ private[spark] class DiskBlockObjectWriter( } } - override def write(key: Any, value: Any) { + /** + * Writes a key-value pair. + */ + def write(key: Any, value: Any) { if (!initialized) { open() } @@ -221,7 +192,10 @@ private[spark] class DiskBlockObjectWriter( bs.write(kvBytes, offs, len) } - override def recordWritten(): Unit = { + /** + * Notify the writer that a record worth of bytes has been written with OutputStream#write. + */ + def recordWritten(): Unit = { numRecordsWritten += 1 writeMetrics.incShuffleRecordsWritten(1) @@ -230,7 +204,11 @@ private[spark] class DiskBlockObjectWriter( } } - override def fileSegment(): FileSegment = { + /** + * Returns the file segment of committed data that this Writer has written. + * This is only valid after commitAndClose() has been called. 
+ */ + def fileSegment(): FileSegment = { if (!commitAndCloseHasBeenCalled) { throw new IllegalStateException( "fileSegment() is only valid after commitAndClose() has been called") diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala index 516aaa44d03fc..ae60f3b0cb555 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala @@ -37,7 +37,7 @@ private[spark] class ChainedBuffer(chunkSize: Int) { private var _size: Long = 0 /** - * Feed bytes from this buffer into a BlockObjectWriter. + * Feed bytes from this buffer into a DiskBlockObjectWriter. * * @param pos Offset in the buffer to read from. * @param os OutputStream to read into. diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 757dec66c203b..ba7ec834d622d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -30,7 +30,7 @@ import org.apache.spark._ import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter} -import org.apache.spark.storage.{BlockId, BlockObjectWriter} +import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -250,7 +250,7 @@ private[spark] class ExternalSorter[K, V, C]( // These variables are reset after each flush var objectsWritten: Long = 0 var spillMetrics: ShuffleWriteMetrics = null - var writer: BlockObjectWriter = null + var writer: DiskBlockObjectWriter = null def openWriter(): Unit = { assert (writer == null && spillMetrics == null) spillMetrics = new ShuffleWriteMetrics diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala index 04bb7fc78c13b..f5844d5353be7 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala @@ -19,7 +19,6 @@ package org.apache.spark.util.collection import java.util.Comparator -import org.apache.spark.storage.BlockObjectWriter import org.apache.spark.util.collection.WritablePartitionedPairCollection._ /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala index ae9a48729e201..87a786b02d651 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala @@ -21,9 +21,8 @@ import java.io.InputStream import java.nio.IntBuffer import java.util.Comparator -import org.apache.spark.SparkEnv import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance} -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._ /** @@ -136,7 +135,7 @@ private[spark] class 
PartitionedSerializedPairBuffer[K, V]( // current position in the meta buffer in ints var pos = 0 - def writeNext(writer: BlockObjectWriter): Unit = { + def writeNext(writer: DiskBlockObjectWriter): Unit = { val keyStart = getKeyStartPos(metaBuffer, pos) val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN) pos += RECORD_SIZE diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index 7bc59898658e4..38848e9018c6c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -19,7 +19,7 @@ package org.apache.spark.util.collection import java.util.Comparator -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter /** * A common interface for size-tracking collections of key-value pairs that @@ -51,7 +51,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] { new WritablePartitionedIterator { private[this] var cur = if (it.hasNext) it.next() else null - def writeNext(writer: BlockObjectWriter): Unit = { + def writeNext(writer: DiskBlockObjectWriter): Unit = { writer.write(cur._1._2, cur._2) cur = if (it.hasNext) it.next() else null } @@ -91,11 +91,11 @@ private[spark] object WritablePartitionedPairCollection { } /** - * Iterator that writes elements to a BlockObjectWriter instead of returning them. Each element + * Iterator that writes elements to a DiskBlockObjectWriter instead of returning them. Each element * has an associated partition. */ private[spark] trait WritablePartitionedIterator { - def writeNext(writer: BlockObjectWriter): Unit + def writeNext(writer: DiskBlockObjectWriter): Unit def hasNext(): Boolean diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 542f8f45125a4..cc7342f1ecd78 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -68,8 +68,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte any[SerializerInstance], anyInt(), any[ShuffleWriteMetrics] - )).thenAnswer(new Answer[BlockObjectWriter] { - override def answer(invocation: InvocationOnMock): BlockObjectWriter = { + )).thenAnswer(new Answer[DiskBlockObjectWriter] { + override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = { val args = invocation.getArguments new DiskBlockObjectWriter( args(0).asInstanceOf[BlockId], diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala similarity index 98% rename from core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala rename to core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index 7bdea724fea58..66af6e1a79740 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.Utils 
-class BlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { +class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { var tempDir: File = _ diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala index 6d2459d48d326..3b67f6206495a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala @@ -17,15 +17,20 @@ package org.apache.spark.util.collection -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import com.google.common.io.ByteStreams +import org.mockito.Matchers.any +import org.mockito.Mockito._ +import org.mockito.Mockito.RETURNS_SMART_NULLS +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.Matchers._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.storage.{FileSegment, BlockObjectWriter} +import org.apache.spark.storage.DiskBlockObjectWriter class PartitionedSerializedPairBufferSuite extends SparkFunSuite { test("OrderedInputStream single record") { @@ -79,13 +84,13 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { val struct = SomeStruct("something", 5) buffer.insert(4, 10, struct) val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val writer = new SimpleBlockObjectWriter + val (writer, baos) = createMockWriter() assert(it.hasNext) it.nextPartition should be (4) it.writeNext(writer) assert(!it.hasNext) - val stream = serializerInstance.deserializeStream(writer.getInputStream) + val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) stream.readObject[AnyRef]() should be (10) stream.readObject[AnyRef]() should be (struct) } @@ -101,7 +106,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { buffer.insert(5, 3, struct3) val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val writer = new SimpleBlockObjectWriter + val (writer, baos) = createMockWriter() assert(it.hasNext) it.nextPartition should be (4) it.writeNext(writer) @@ -113,7 +118,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { it.writeNext(writer) assert(!it.hasNext) - val stream = serializerInstance.deserializeStream(writer.getInputStream) + val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) val iter = stream.asIterator iter.next() should be (2) iter.next() should be (struct2) @@ -123,26 +128,21 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { iter.next() should be (struct1) assert(!iter.hasNext) } -} - -case class SomeStruct(val str: String, val num: Int) - -class SimpleBlockObjectWriter extends BlockObjectWriter(null) { - val baos = new ByteArrayOutputStream() - override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = { - baos.write(bytes, offs, len) + def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = { + val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS) + val baos = new ByteArrayOutputStream() + when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocationOnMock: 
InvocationOnMock): Unit = { + val args = invocationOnMock.getArguments + val bytes = args(0).asInstanceOf[Array[Byte]] + val offset = args(1).asInstanceOf[Int] + val length = args(2).asInstanceOf[Int] + baos.write(bytes, offset, length) + } + }) + (writer, baos) } - - def getInputStream(): InputStream = new ByteArrayInputStream(baos.toByteArray) - - override def open(): BlockObjectWriter = this - override def close(): Unit = { } - override def isOpen: Boolean = true - override def commitAndClose(): Unit = { } - override def revertPartialWritesAndClose(): Unit = { } - override def fileSegment(): FileSegment = null - override def write(key: Any, value: Any): Unit = { } - override def recordWritten(): Unit = { } - override def write(b: Int): Unit = { } } + +case class SomeStruct(str: String, num: Int) From 0a4071eab30db1db80f61ed2cb2e7243291183ce Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Tue, 14 Jul 2015 13:14:47 -0700 Subject: [PATCH 03/46] [SPARK-8718] [GRAPHX] Improve EdgePartition2D for non perfect square number of partitions See https://github.com/aray/e2d/blob/master/EdgePartition2D.ipynb Author: Andrew Ray Closes #7104 from aray/edge-partition-2d-improvement and squashes the following commits: 3729f84 [Andrew Ray] correct bounds and remove unneeded comments 97f8464 [Andrew Ray] change less 5141ab4 [Andrew Ray] Merge branch 'master' into edge-partition-2d-improvement 925fd2c [Andrew Ray] use new interface for partitioning 001bfd0 [Andrew Ray] Refactor PartitionStrategy so that we can return a prtition function for a given number of parts. To keep compatibility we define default methods that translate between the two implementation options. Made EdgePartition2D use old strategy when we have a perfect square and implement new interface. 5d42105 [Andrew Ray] % -> / 3560084 [Andrew Ray] Merge branch 'master' into edge-partition-2d-improvement f006364 [Andrew Ray] remove unneeded comments cfa2c5e [Andrew Ray] Modifications to EdgePartition2D so that it works for non perfect squares. --- .../spark/graphx/PartitionStrategy.scala | 32 ++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala index 7372dfbd9fe98..70a7592da8ae3 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala @@ -32,7 +32,7 @@ trait PartitionStrategy extends Serializable { object PartitionStrategy { /** * Assigns edges to partitions using a 2D partitioning of the sparse edge adjacency matrix, - * guaranteeing a `2 * sqrt(numParts) - 1` bound on vertex replication. + * guaranteeing a `2 * sqrt(numParts)` bound on vertex replication. * * Suppose we have a graph with 12 vertices that we want to partition * over 9 machines. We can use the following sparse matrix representation: @@ -61,26 +61,36 @@ object PartitionStrategy { * that edges adjacent to `v11` can only be in the first column of blocks `(P0, P3, * P6)` or the last * row of blocks `(P6, P7, P8)`. As a consequence we can guarantee that `v11` will need to be - * replicated to at most `2 * sqrt(numParts) - 1` machines. + * replicated to at most `2 * sqrt(numParts)` machines. * * Notice that `P0` has many edges and as a consequence this partitioning would lead to poor work * balance. To improve balance we first multiply each vertex id by a large prime to shuffle the * vertex locations. 
* - * One of the limitations of this approach is that the number of machines must either be a - * perfect square. We partially address this limitation by computing the machine assignment to - * the next - * largest perfect square and then mapping back down to the actual number of machines. - * Unfortunately, this can also lead to work imbalance and so it is suggested that a perfect - * square is used. + * When the number of partitions requested is not a perfect square we use a slightly different + * method where the last column can have a different number of rows than the others while still + * maintaining the same size per block. */ case object EdgePartition2D extends PartitionStrategy { override def getPartition(src: VertexId, dst: VertexId, numParts: PartitionID): PartitionID = { val ceilSqrtNumParts: PartitionID = math.ceil(math.sqrt(numParts)).toInt val mixingPrime: VertexId = 1125899906842597L - val col: PartitionID = (math.abs(src * mixingPrime) % ceilSqrtNumParts).toInt - val row: PartitionID = (math.abs(dst * mixingPrime) % ceilSqrtNumParts).toInt - (col * ceilSqrtNumParts + row) % numParts + if (numParts == ceilSqrtNumParts * ceilSqrtNumParts) { + // Use old method for perfect squared to ensure we get same results + val col: PartitionID = (math.abs(src * mixingPrime) % ceilSqrtNumParts).toInt + val row: PartitionID = (math.abs(dst * mixingPrime) % ceilSqrtNumParts).toInt + (col * ceilSqrtNumParts + row) % numParts + + } else { + // Otherwise use new method + val cols = ceilSqrtNumParts + val rows = (numParts + cols - 1) / cols + val lastColRows = numParts - rows * (cols - 1) + val col = (math.abs(src * mixingPrime) % numParts / rows).toInt + val row = (math.abs(dst * mixingPrime) % (if (col < cols - 1) rows else lastColRows)).toInt + col * rows + row + + } } } From fb1d06fc242ec00320f1a3049673fbb03c4a6eb9 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 14 Jul 2015 13:58:36 -0700 Subject: [PATCH 04/46] [SPARK-4072] [CORE] Display Streaming blocks in Streaming UI Replace #6634 This PR adds `SparkListenerBlockUpdated` to SparkListener so that it can monitor all block update infos that are sent to `BlockManagerMasaterEndpoint`, and also add new tables in the Storage tab to display the stream block infos. 
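For reference, a minimal sketch of how a user-defined listener could consume the new event. This is illustrative only and not part of the patch; the class name `StreamBlockLogger` is hypothetical, while `SparkListenerBlockUpdated`, `BlockUpdatedInfo` and `onBlockUpdated` are the additions described above:

```scala
// Illustrative sketch only, not part of this patch.
import org.apache.spark.scheduler.{SparkListener, SparkListenerBlockUpdated}
import org.apache.spark.storage.StreamBlockId

class StreamBlockLogger extends SparkListener {
  override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
    val info = blockUpdated.blockUpdatedInfo
    info.blockId match {
      case id: StreamBlockId =>
        // BlockUpdatedInfo carries the block manager, storage level and block sizes.
        println(s"$id on ${info.blockManagerId.hostPort}: level=${info.storageLevel}, " +
          s"mem=${info.memSize}, disk=${info.diskSize}")
      case _ => // ignore non-stream blocks, mirroring what BlockStatusListener does
    }
  }
}
```

Such a listener could then be registered with `sc.addSparkListener(new StreamBlockLogger)`.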
![screen shot 2015-07-01 at 5 19 46 pm](https://cloud.githubusercontent.com/assets/1000778/8451562/c291a6ec-2016-11e5-890d-0afc174e1f8c.png) Author: zsxwing Closes #6672 from zsxwing/SPARK-4072-2 and squashes the following commits: df2c1d8 [zsxwing] Use xml query to check the xml elements 54d54af [zsxwing] Add unit tests for StoragePage e29fb53 [zsxwing] Update as per TD's comments ccbee07 [zsxwing] Fix the code style 6dc42b4 [zsxwing] Fix the replication level of blocks 450fad1 [zsxwing] Merge branch 'master' into SPARK-4072-2 1e9ef52 [zsxwing] Don't categorize by Executor ID ca0ab69 [zsxwing] Fix the code style 3de2762 [zsxwing] Make object BlockUpdatedInfo private e95b594 [zsxwing] Add 'Aggregated Stream Block Metrics by Executor' table ba5d0d1 [zsxwing] Refactor the unit test to improve the readability 4bbe341 [zsxwing] Revert JsonProtocol and don't log SparkListenerBlockUpdated b464dd1 [zsxwing] Add onBlockUpdated to EventLoggingListener 5ba014c [zsxwing] Fix the code style 0b1e47b [zsxwing] Add a developer api BlockUpdatedInfo 04838a9 [zsxwing] Fix the code style 2baa161 [zsxwing] Add unit tests 80f6c6d [zsxwing] Address comments 797ee4b [zsxwing] Display Streaming blocks in Streaming UI --- .../org/apache/spark/JavaSparkListener.java | 22 +- .../apache/spark/SparkFirehoseListener.java | 6 + .../scheduler/EventLoggingListener.scala | 3 + .../spark/scheduler/SparkListener.scala | 10 +- .../spark/scheduler/SparkListenerBus.scala | 2 + .../storage/BlockManagerMasterEndpoint.scala | 3 +- .../spark/storage/BlockStatusListener.scala | 105 ++++++++ .../spark/storage/BlockUpdatedInfo.scala | 47 ++++ .../scala/org/apache/spark/ui/UIUtils.scala | 14 +- .../apache/spark/ui/storage/StoragePage.scala | 148 ++++++++++- .../apache/spark/ui/storage/StorageTab.scala | 3 +- .../storage/BlockStatusListenerSuite.scala | 119 +++++++++ .../spark/ui/storage/StoragePageSuite.scala | 230 ++++++++++++++++++ 13 files changed, 684 insertions(+), 28 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala create mode 100644 core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java index 646496f313507..fa9acf0a15b88 100644 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -17,23 +17,7 @@ package org.apache.spark; -import org.apache.spark.scheduler.SparkListener; -import org.apache.spark.scheduler.SparkListenerApplicationEnd; -import org.apache.spark.scheduler.SparkListenerApplicationStart; -import org.apache.spark.scheduler.SparkListenerBlockManagerAdded; -import org.apache.spark.scheduler.SparkListenerBlockManagerRemoved; -import org.apache.spark.scheduler.SparkListenerEnvironmentUpdate; -import org.apache.spark.scheduler.SparkListenerExecutorAdded; -import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate; -import org.apache.spark.scheduler.SparkListenerExecutorRemoved; -import org.apache.spark.scheduler.SparkListenerJobEnd; -import org.apache.spark.scheduler.SparkListenerJobStart; -import org.apache.spark.scheduler.SparkListenerStageCompleted; -import org.apache.spark.scheduler.SparkListenerStageSubmitted; -import 
org.apache.spark.scheduler.SparkListenerTaskEnd; -import org.apache.spark.scheduler.SparkListenerTaskGettingResult; -import org.apache.spark.scheduler.SparkListenerTaskStart; -import org.apache.spark.scheduler.SparkListenerUnpersistRDD; +import org.apache.spark.scheduler.*; /** * Java clients should extend this class instead of implementing @@ -94,4 +78,8 @@ public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { } @Override public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } + + @Override + public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } + } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index fbc5666959055..1214d05ba6063 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -112,4 +112,10 @@ public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { onEvent(executorRemoved); } + + @Override + public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { + onEvent(blockUpdated); + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 62b05033a9281..5a06ef02f5c57 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -199,6 +199,9 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) } + // No-op because logging every update would be overkill + override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = {} + // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 9620915f495ab..896f1743332f1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -26,7 +26,7 @@ import org.apache.spark.{Logging, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.util.{Distribution, Utils} @DeveloperApi @@ -98,6 +98,9 @@ case class SparkListenerExecutorAdded(time: Long, executorId: String, executorIn case class SparkListenerExecutorRemoved(time: Long, executorId: String, reason: String) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends SparkListenerEvent + /** * Periodic updates from executors. * @param execId executor id @@ -215,6 +218,11 @@ trait SparkListener { * Called when the driver removes an executor. */ def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { } + + /** + * Called when the driver receives a block update info. 
+ */ + def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 61e69ecc08387..04afde33f5aad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -58,6 +58,8 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => listener.onExecutorRemoved(executorRemoved) + case blockUpdated: SparkListenerBlockUpdated => + listener.onBlockUpdated(blockUpdated) case logStart: SparkListenerLogStart => // ignore event log metadata } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 68ed9096731c5..5dc0c537cbb62 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -60,10 +60,11 @@ class BlockManagerMasterEndpoint( register(blockManagerId, maxMemSize, slaveEndpoint) context.reply(true) - case UpdateBlockInfo( + case _updateBlockInfo @ UpdateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize) => context.reply(updateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize)) + listenerBus.post(SparkListenerBlockUpdated(BlockUpdatedInfo(_updateBlockInfo))) case GetLocations(blockId) => context.reply(getLocations(blockId)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala new file mode 100644 index 0000000000000..2789e25b8d3ab --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import scala.collection.mutable + +import org.apache.spark.scheduler._ + +private[spark] case class BlockUIData( + blockId: BlockId, + location: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + externalBlockStoreSize: Long) + +/** + * The aggregated status of stream blocks in an executor + */ +private[spark] case class ExecutorStreamBlockStatus( + executorId: String, + location: String, + blocks: Seq[BlockUIData]) { + + def totalMemSize: Long = blocks.map(_.memSize).sum + + def totalDiskSize: Long = blocks.map(_.diskSize).sum + + def totalExternalBlockStoreSize: Long = blocks.map(_.externalBlockStoreSize).sum + + def numStreamBlocks: Int = blocks.size + +} + +private[spark] class BlockStatusListener extends SparkListener { + + private val blockManagers = + new mutable.HashMap[BlockManagerId, mutable.HashMap[BlockId, BlockUIData]] + + override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { + val blockId = blockUpdated.blockUpdatedInfo.blockId + if (!blockId.isInstanceOf[StreamBlockId]) { + // Now we only monitor StreamBlocks + return + } + val blockManagerId = blockUpdated.blockUpdatedInfo.blockManagerId + val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel + val memSize = blockUpdated.blockUpdatedInfo.memSize + val diskSize = blockUpdated.blockUpdatedInfo.diskSize + val externalBlockStoreSize = blockUpdated.blockUpdatedInfo.externalBlockStoreSize + + synchronized { + // Drop the update info if the block manager is not registered + blockManagers.get(blockManagerId).foreach { blocksInBlockManager => + if (storageLevel.isValid) { + blocksInBlockManager.put(blockId, + BlockUIData( + blockId, + blockManagerId.hostPort, + storageLevel, + memSize, + diskSize, + externalBlockStoreSize) + ) + } else { + // If isValid is not true, it means we should drop the block. + blocksInBlockManager -= blockId + } + } + } + } + + override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { + synchronized { + blockManagers.put(blockManagerAdded.blockManagerId, mutable.HashMap()) + } + } + + override def onBlockManagerRemoved( + blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = synchronized { + blockManagers -= blockManagerRemoved.blockManagerId + } + + def allExecutorStreamBlockStatus: Seq[ExecutorStreamBlockStatus] = synchronized { + blockManagers.map { case (blockManagerId, blocks) => + ExecutorStreamBlockStatus( + blockManagerId.executorId, blockManagerId.hostPort, blocks.values.toSeq) + }.toSeq + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala new file mode 100644 index 0000000000000..a5790e4454a89 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.storage.BlockManagerMessages.UpdateBlockInfo + +/** + * :: DeveloperApi :: + * Stores information about a block status in a block manager. + */ +@DeveloperApi +case class BlockUpdatedInfo( + blockManagerId: BlockManagerId, + blockId: BlockId, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + externalBlockStoreSize: Long) + +private[spark] object BlockUpdatedInfo { + + private[spark] def apply(updateBlockInfo: UpdateBlockInfo): BlockUpdatedInfo = { + BlockUpdatedInfo( + updateBlockInfo.blockManagerId, + updateBlockInfo.blockId, + updateBlockInfo.storageLevel, + updateBlockInfo.memSize, + updateBlockInfo.diskSize, + updateBlockInfo.externalBlockStoreSize) + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 7898039519201..718aea7e1dc22 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.ui.scope.RDDOperationGraph /** Utility functions for generating XML pages with spark content. */ private[spark] object UIUtils extends Logging { - val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed sortable" + val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed" val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. 
@@ -267,9 +267,17 @@ private[spark] object UIUtils extends Logging { fixedWidth: Boolean = false, id: Option[String] = None, headerClasses: Seq[String] = Seq.empty, - stripeRowsWithCss: Boolean = true): Seq[Node] = { + stripeRowsWithCss: Boolean = true, + sortable: Boolean = true): Seq[Node] = { - val listingTableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED + val listingTableClass = { + val _tableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED + if (sortable) { + _tableClass + " sortable" + } else { + _tableClass + } + } val colWidth = 100.toDouble / headers.size val colWidthAttr = if (fixedWidth) colWidth + "%" else "" diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 07db783c572cf..04f584621e71e 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.storage.RDDInfo +import org.apache.spark.storage._ import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils @@ -30,13 +30,25 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { - val rdds = listener.rddInfoList - val content = UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table")) + val content = rddTable(listener.rddInfoList) ++ + receiverBlockTables(listener.allExecutorStreamBlockStatus.sortBy(_.executorId)) UIUtils.headerSparkPage("Storage", content, parent) } + private[storage] def rddTable(rdds: Seq[RDDInfo]): Seq[Node] = { + if (rdds.isEmpty) { + // Don't show the rdd table if there is no RDD persisted. + Nil + } else { +
+      <div>
+        <h4>RDDs</h4>
+        {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))}
+      </div>
+ } + } + /** Header fields for the RDD table */ - private def rddHeader = Seq( + private val rddHeader = Seq( "RDD Name", "Storage Level", "Cached Partitions", @@ -56,7 +68,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { {rdd.storageLevel.description} - {rdd.numCachedPartitions} + {rdd.numCachedPartitions.toString} {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} {Utils.bytesToString(rdd.memSize)} {Utils.bytesToString(rdd.externalBlockStoreSize)} @@ -64,4 +76,130 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { // scalastyle:on } + + private[storage] def receiverBlockTables(statuses: Seq[ExecutorStreamBlockStatus]): Seq[Node] = { + if (statuses.map(_.numStreamBlocks).sum == 0) { + // Don't show the tables if there is no stream block + Nil + } else { + val blocks = statuses.flatMap(_.blocks).groupBy(_.blockId).toSeq.sortBy(_._1.toString) + +
+      <div>
+        <h4>Receiver Blocks</h4>
+        {executorMetricsTable(statuses)}
+        {streamBlockTable(blocks)}
+      </div>
+ } + } + + private def executorMetricsTable(statuses: Seq[ExecutorStreamBlockStatus]): Seq[Node] = { +
+    <div>
+      <h5>Aggregated Block Metrics by Executor</h5>
+      {UIUtils.listingTable(executorMetricsTableHeader, executorMetricsTableRow, statuses,
+        id = Some("storage-by-executor-stream-blocks"))}
+    </div>
+ } + + private val executorMetricsTableHeader = Seq( + "Executor ID", + "Address", + "Total Size in Memory", + "Total Size in ExternalBlockStore", + "Total Size on Disk", + "Stream Blocks") + + private def executorMetricsTableRow(status: ExecutorStreamBlockStatus): Seq[Node] = { + + + {status.executorId} + + + {status.location} + + + {Utils.bytesToString(status.totalMemSize)} + + + {Utils.bytesToString(status.totalExternalBlockStoreSize)} + + + {Utils.bytesToString(status.totalDiskSize)} + + + {status.numStreamBlocks.toString} + + + } + + private def streamBlockTable(blocks: Seq[(BlockId, Seq[BlockUIData])]): Seq[Node] = { + if (blocks.isEmpty) { + Nil + } else { +
+      <div>
+        <h5>Blocks</h5>
+        {UIUtils.listingTable(
+          streamBlockTableHeader,
+          streamBlockTableRow,
+          blocks,
+          id = Some("storage-by-block-table"),
+          sortable = false)}
+      </div>
+ } + } + + private val streamBlockTableHeader = Seq( + "Block ID", + "Replication Level", + "Location", + "Storage Level", + "Size") + + /** Render a stream block */ + private def streamBlockTableRow(block: (BlockId, Seq[BlockUIData])): Seq[Node] = { + val replications = block._2 + assert(replications.size > 0) // This must be true because it's the result of "groupBy" + if (replications.size == 1) { + streamBlockTableSubrow(block._1, replications.head, replications.size, true) + } else { + streamBlockTableSubrow(block._1, replications.head, replications.size, true) ++ + replications.tail.map(streamBlockTableSubrow(block._1, _, replications.size, false)).flatten + } + } + + private def streamBlockTableSubrow( + blockId: BlockId, block: BlockUIData, replication: Int, firstSubrow: Boolean): Seq[Node] = { + val (storageLevel, size) = streamBlockStorageLevelDescriptionAndSize(block) + + + { + if (firstSubrow) { + + {block.blockId.toString} + + + {replication.toString} + + } + } + {block.location} + {storageLevel} + {Utils.bytesToString(size)} + + } + + private[storage] def streamBlockStorageLevelDescriptionAndSize( + block: BlockUIData): (String, Long) = { + if (block.storageLevel.useDisk) { + ("Disk", block.diskSize) + } else if (block.storageLevel.useMemory && block.storageLevel.deserialized) { + ("Memory", block.memSize) + } else if (block.storageLevel.useMemory && !block.storageLevel.deserialized) { + ("Memory Serialized", block.memSize) + } else if (block.storageLevel.useOffHeap) { + ("External", block.externalBlockStoreSize) + } else { + throw new IllegalStateException(s"Invalid Storage Level: ${block.storageLevel}") + } + } + } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 0351749700962..22e2993b3b5bd 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -39,7 +39,8 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi -class StorageListener(storageStatusListener: StorageStatusListener) extends SparkListener { +class StorageListener(storageStatusListener: StorageStatusListener) extends BlockStatusListener { + private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList diff --git a/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala new file mode 100644 index 0000000000000..d7ffde1e7864e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler._ + +class BlockStatusListenerSuite extends SparkFunSuite { + + test("basic functions") { + val blockManagerId = BlockManagerId("0", "localhost", 10000) + val listener = new BlockStatusListener() + + // Add a block manager and a new block status + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(0, blockManagerId, 0)) + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo( + blockManagerId, + StreamBlockId(0, 100), + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0))) + // The new block status should be added to the listener + val expectedBlock = BlockUIData( + StreamBlockId(0, 100), + "localhost:10000", + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0 + ) + val expectedExecutorStreamBlockStatus = Seq( + ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)) + ) + assert(listener.allExecutorStreamBlockStatus === expectedExecutorStreamBlockStatus) + + // Add the second block manager + val blockManagerId2 = BlockManagerId("1", "localhost", 10001) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(0, blockManagerId2, 0)) + // Add a new replication of the same block id from the second manager + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo( + blockManagerId2, + StreamBlockId(0, 100), + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0))) + val expectedBlock2 = BlockUIData( + StreamBlockId(0, 100), + "localhost:10001", + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize = 0 + ) + // Each block manager should contain one block + val expectedExecutorStreamBlockStatus2 = Set( + ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)), + ExecutorStreamBlockStatus("1", "localhost:10001", Seq(expectedBlock2)) + ) + assert(listener.allExecutorStreamBlockStatus.toSet === expectedExecutorStreamBlockStatus2) + + // Remove a replication of the same block + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo( + blockManagerId2, + StreamBlockId(0, 100), + StorageLevel.NONE, // StorageLevel.NONE means removing it + memSize = 0, + diskSize = 0, + externalBlockStoreSize = 0))) + // Only the first block manager contains a block + val expectedExecutorStreamBlockStatus3 = Set( + ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)), + ExecutorStreamBlockStatus("1", "localhost:10001", Seq.empty) + ) + assert(listener.allExecutorStreamBlockStatus.toSet === expectedExecutorStreamBlockStatus3) + + // Remove the second block manager at first but add a new block status + // from this removed block manager + listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(0, blockManagerId2)) + listener.onBlockUpdated(SparkListenerBlockUpdated( + BlockUpdatedInfo( + blockManagerId2, + StreamBlockId(0, 100), + StorageLevel.MEMORY_AND_DISK, + memSize = 100, + diskSize = 100, + externalBlockStoreSize 
= 0))) + // The second block manager is removed so we should not see the new block + val expectedExecutorStreamBlockStatus4 = Seq( + ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)) + ) + assert(listener.allExecutorStreamBlockStatus === expectedExecutorStreamBlockStatus4) + + // Remove the last block manager + listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(0, blockManagerId)) + // No block manager now so we should drop all block managers + assert(listener.allExecutorStreamBlockStatus.isEmpty) + } + +} diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala new file mode 100644 index 0000000000000..3dab15a9d4691 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.storage + +import scala.xml.Utility + +import org.mockito.Mockito._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.storage._ + +class StoragePageSuite extends SparkFunSuite { + + val storageTab = mock(classOf[StorageTab]) + when(storageTab.basePath).thenReturn("http://localhost:4040") + val storagePage = new StoragePage(storageTab) + + test("rddTable") { + val rdd1 = new RDDInfo(1, + "rdd1", + 10, + StorageLevel.MEMORY_ONLY, + Seq.empty) + rdd1.memSize = 100 + rdd1.numCachedPartitions = 10 + + val rdd2 = new RDDInfo(2, + "rdd2", + 10, + StorageLevel.DISK_ONLY, + Seq.empty) + rdd2.diskSize = 200 + rdd2.numCachedPartitions = 5 + + val rdd3 = new RDDInfo(3, + "rdd3", + 10, + StorageLevel.MEMORY_AND_DISK_SER, + Seq.empty) + rdd3.memSize = 400 + rdd3.diskSize = 500 + rdd3.numCachedPartitions = 10 + + val xmlNodes = storagePage.rddTable(Seq(rdd1, rdd2, rdd3)) + + val headers = Seq( + "RDD Name", + "Storage Level", + "Cached Partitions", + "Fraction Cached", + "Size in Memory", + "Size in ExternalBlockStore", + "Size on Disk") + assert((xmlNodes \\ "th").map(_.text) === headers) + + assert((xmlNodes \\ "tr").size === 3) + assert(((xmlNodes \\ "tr")(0) \\ "td").map(_.text.trim) === + Seq("rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B", "0.0 B")) + // Check the url + assert(((xmlNodes \\ "tr")(0) \\ "td" \ "a")(0).attribute("href").map(_.text) === + Some("http://localhost:4040/storage/rdd?id=1")) + + assert(((xmlNodes \\ "tr")(1) \\ "td").map(_.text.trim) === + Seq("rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "0.0 B", "200.0 B")) + // Check the url + assert(((xmlNodes \\ "tr")(1) \\ "td" \ "a")(0).attribute("href").map(_.text) === + Some("http://localhost:4040/storage/rdd?id=2")) + + assert(((xmlNodes \\ "tr")(2) \\ "td").map(_.text.trim) === + Seq("rdd3",
"Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "0.0 B", + "500.0 B")) + // Check the url + assert(((xmlNodes \\ "tr")(2) \\ "td" \ "a")(0).attribute("href").map(_.text) === + Some("http://localhost:4040/storage/rdd?id=3")) + } + + test("empty rddTable") { + assert(storagePage.rddTable(Seq.empty).isEmpty) + } + + test("streamBlockStorageLevelDescriptionAndSize") { + val memoryBlock = BlockUIData(StreamBlockId(0, 0), + "localhost:1111", + StorageLevel.MEMORY_ONLY, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0) + assert(("Memory", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(memoryBlock)) + + val memorySerializedBlock = BlockUIData(StreamBlockId(0, 0), + "localhost:1111", + StorageLevel.MEMORY_ONLY_SER, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0) + assert(("Memory Serialized", 100) === + storagePage.streamBlockStorageLevelDescriptionAndSize(memorySerializedBlock)) + + val diskBlock = BlockUIData(StreamBlockId(0, 0), + "localhost:1111", + StorageLevel.DISK_ONLY, + memSize = 0, + diskSize = 100, + externalBlockStoreSize = 0) + assert(("Disk", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(diskBlock)) + + val externalBlock = BlockUIData(StreamBlockId(0, 0), + "localhost:1111", + StorageLevel.OFF_HEAP, + memSize = 0, + diskSize = 0, + externalBlockStoreSize = 100) + assert(("External", 100) === + storagePage.streamBlockStorageLevelDescriptionAndSize(externalBlock)) + } + + test("receiverBlockTables") { + val blocksForExecutor0 = Seq( + BlockUIData(StreamBlockId(0, 0), + "localhost:10000", + StorageLevel.MEMORY_ONLY, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0), + BlockUIData(StreamBlockId(1, 1), + "localhost:10000", + StorageLevel.DISK_ONLY, + memSize = 0, + diskSize = 100, + externalBlockStoreSize = 0) + ) + val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", blocksForExecutor0) + + val blocksForExecutor1 = Seq( + BlockUIData(StreamBlockId(0, 0), + "localhost:10001", + StorageLevel.MEMORY_ONLY, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0), + BlockUIData(StreamBlockId(2, 2), + "localhost:10001", + StorageLevel.OFF_HEAP, + memSize = 0, + diskSize = 0, + externalBlockStoreSize = 200), + BlockUIData(StreamBlockId(1, 1), + "localhost:10001", + StorageLevel.MEMORY_ONLY_SER, + memSize = 100, + diskSize = 0, + externalBlockStoreSize = 0) + ) + val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", blocksForExecutor1) + val xmlNodes = storagePage.receiverBlockTables(Seq(executor0, executor1)) + + val executorTable = (xmlNodes \\ "table")(0) + val executorHeaders = Seq( + "Executor ID", + "Address", + "Total Size in Memory", + "Total Size in ExternalBlockStore", + "Total Size on Disk", + "Stream Blocks") + assert((executorTable \\ "th").map(_.text) === executorHeaders) + + assert((executorTable \\ "tr").size === 2) + assert(((executorTable \\ "tr")(0) \\ "td").map(_.text.trim) === + Seq("0", "localhost:10000", "100.0 B", "0.0 B", "100.0 B", "2")) + assert(((executorTable \\ "tr")(1) \\ "td").map(_.text.trim) === + Seq("1", "localhost:10001", "200.0 B", "200.0 B", "0.0 B", "3")) + + val blockTable = (xmlNodes \\ "table")(1) + val blockHeaders = Seq( + "Block ID", + "Replication Level", + "Location", + "Storage Level", + "Size") + assert((blockTable \\ "th").map(_.text) === blockHeaders) + + assert((blockTable \\ "tr").size === 5) + assert(((blockTable \\ "tr")(0) \\ "td").map(_.text.trim) === + Seq("input-0-0", "2", "localhost:10000", "Memory", "100.0 B")) + // 
Check "rowspan=2" for the first 2 columns + assert(((blockTable \\ "tr")(0) \\ "td")(0).attribute("rowspan").map(_.text) === Some("2")) + assert(((blockTable \\ "tr")(0) \\ "td")(1).attribute("rowspan").map(_.text) === Some("2")) + + assert(((blockTable \\ "tr")(1) \\ "td").map(_.text.trim) === + Seq("localhost:10001", "Memory", "100.0 B")) + + assert(((blockTable \\ "tr")(2) \\ "td").map(_.text.trim) === + Seq("input-1-1", "2", "localhost:10000", "Disk", "100.0 B")) + // Check "rowspan=2" for the first 2 columns + assert(((blockTable \\ "tr")(2) \\ "td")(0).attribute("rowspan").map(_.text) === Some("2")) + assert(((blockTable \\ "tr")(2) \\ "td")(1).attribute("rowspan").map(_.text) === Some("2")) + + assert(((blockTable \\ "tr")(3) \\ "td").map(_.text.trim) === + Seq("localhost:10001", "Memory Serialized", "100.0 B")) + + assert(((blockTable \\ "tr")(4) \\ "td").map(_.text.trim) === + Seq("input-2-2", "1", "localhost:10001", "External", "200.0 B")) + // Check "rowspan=1" for the first 2 columns + assert(((blockTable \\ "tr")(4) \\ "td")(0).attribute("rowspan").map(_.text) === Some("1")) + assert(((blockTable \\ "tr")(4) \\ "td")(1).attribute("rowspan").map(_.text) === Some("1")) + } + + test("empty receiverBlockTables") { + assert(storagePage.receiverBlockTables(Seq.empty).isEmpty) + + val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", Seq.empty) + val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", Seq.empty) + assert(storagePage.receiverBlockTables(Seq(executor0, executor1)).isEmpty) + } +} From 4b5cfc988f23988c2334882a255d494fc93d252e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 14 Jul 2015 14:19:27 -0700 Subject: [PATCH 05/46] [SPARK-8800] [SQL] Fix inaccurate precision/scale of Decimal division operation JIRA: https://issues.apache.org/jira/browse/SPARK-8800 Previously, we turn to Java BigDecimal's divide with specified ROUNDING_MODE to avoid non-terminating decimal expansion problem. However, as JihongMA reported, for the division operation on some specific values, we get inaccurate results. Author: Liang-Chi Hsieh Closes #7212 from viirya/fix_decimal4 and squashes the following commits: 4205a0a [Liang-Chi Hsieh] Fix inaccuracy precision/scale of Decimal division operation. --- .../scala/org/apache/spark/sql/types/Decimal.scala | 14 +++++++++++--- .../spark/sql/types/decimal/DecimalSuite.scala | 10 +++++++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 5a169488c97eb..f5bd068d60dc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -145,6 +145,14 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } + def toLimitedBigDecimal: BigDecimal = { + if (decimalVal.ne(null)) { + decimalVal + } else { + BigDecimal(longVal, _scale) + } + } + def toJavaBigDecimal: java.math.BigDecimal = toBigDecimal.underlying() def toUnscaledLong: Long = { @@ -269,9 +277,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (that.isZero) { null } else { - // To avoid non-terminating decimal expansion problem, we turn to Java BigDecimal's divide - // with specified ROUNDING_MODE. 
- Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, ROUNDING_MODE.id)) + // To avoid non-terminating decimal expansion problem, we get Scala's BigDecimal with limited + // precision and scale. + Decimal(toLimitedBigDecimal / that.toLimitedBigDecimal) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index 5f312964e5bf7..030bb6d21b18b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -170,6 +170,14 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { test("fix non-terminating decimal expansion problem") { val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3) - assert(decimal.toString === "0.333") + // The difference from the expected value should not be more than 0.001. + assert(decimal.toDouble - 0.333 < 0.001) + } + + test("fix loss of precision/scale when doing division operation") { + val a = Decimal(2) / Decimal(3) + assert(a.toDouble < 1.0 && a.toDouble > 0.6) + val b = Decimal(1) / Decimal(8) + assert(b.toDouble === 0.125) } } From 740b034f1ca885a386f5a9ef7e0c81c714b047ff Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 14 Jul 2015 22:44:54 +0100 Subject: [PATCH 06/46] [SPARK-4362] [MLLIB] Make prediction probability available in NaiveBayesModel Add predictProbabilities to Naive Bayes, return class probabilities. Continues https://github.com/apache/spark/pull/6761 Author: Sean Owen Closes #7376 from srowen/SPARK-4362 and squashes the following commits: 23d5a76 [Sean Owen] Fix model.labels -> model.theta 95d91fb [Sean Owen] Check that predicted probabilities sum to 1 b32d1c8 [Sean Owen] Add predictProbabilities to Naive Bayes, return class probabilities --- .../mllib/classification/NaiveBayes.scala | 76 +++++++++++++++---- .../classification/NaiveBayesSuite.scala | 55 +++++++++++++- 2 files changed, 113 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index f51ee36d0dfcb..9e379d7d74b2f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -93,26 +93,70 @@ class NaiveBayesModel private[mllib] ( override def predict(testData: Vector): Double = { modelType match { case Multinomial => - val prob = thetaMatrix.multiply(testData) - BLAS.axpy(1.0, piVector, prob) - labels(prob.argmax) + labels(multinomialCalculation(testData).argmax) case Bernoulli => - testData.foreachActive { (index, value) => - if (value != 0.0 && value != 1.0) { - throw new SparkException( - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.") - } - } - val prob = thetaMinusNegTheta.get.multiply(testData) - BLAS.axpy(1.0, piVector, prob) - BLAS.axpy(1.0, negThetaSum.get, prob) - labels(prob.argmax) - case _ => - // This should never happen. - throw new UnknownError(s"Invalid modelType: $modelType.") + labels(bernoulliCalculation(testData).argmax) + } + } + + /** + * Predict values for the given data set using the model trained.
+ * + * @param testData RDD representing data points to be predicted + * @return an RDD[Vector] where each entry contains the predicted posterior class probabilities, + * in the same order as class labels + */ + def predictProbabilities(testData: RDD[Vector]): RDD[Vector] = { + val bcModel = testData.context.broadcast(this) + testData.mapPartitions { iter => + val model = bcModel.value + iter.map(model.predictProbabilities) } } + /** + * Predict posterior class probabilities for a single data point using the model trained. + * + * @param testData array representing a single data point + * @return predicted posterior class probabilities from the trained model, + * in the same order as class labels + */ + def predictProbabilities(testData: Vector): Vector = { + modelType match { + case Multinomial => + posteriorProbabilities(multinomialCalculation(testData)) + case Bernoulli => + posteriorProbabilities(bernoulliCalculation(testData)) + } + } + + private def multinomialCalculation(testData: Vector) = { + val prob = thetaMatrix.multiply(testData) + BLAS.axpy(1.0, piVector, prob) + prob + } + + private def bernoulliCalculation(testData: Vector) = { + testData.foreachActive((_, value) => + if (value != 0.0 && value != 1.0) { + throw new SparkException( + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.") + } + ) + val prob = thetaMinusNegTheta.get.multiply(testData) + BLAS.axpy(1.0, piVector, prob) + BLAS.axpy(1.0, negThetaSum.get, prob) + prob + } + + private def posteriorProbabilities(logProb: DenseVector) = { + val logProbArray = logProb.toArray + val maxLog = logProbArray.max + val scaledProbs = logProbArray.map(lp => math.exp(lp - maxLog)) + val probSum = scaledProbs.sum + new DenseVector(scaledProbs.map(_ / probSum)) + } + override def save(sc: SparkContext, path: String): Unit = { val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType) NaiveBayesModel.SaveLoadV2_0.save(sc, path, data) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index f7fc8730606af..cffa1ab700f80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.mllib.classification import scala.util.Random -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Vector => BV} import breeze.stats.distributions.{Multinomial => BrzMultinomial} import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils object NaiveBayesSuite { @@ -154,6 +155,29 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { // Test prediction on Array. 
validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + + // Test posteriors + validationData.map(_.features).foreach { features => + val predicted = model.predictProbabilities(features).toArray + assert(predicted.sum ~== 1.0 relTol 1.0e-10) + val expected = expectedMultinomialProbabilities(model, features) + expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) } + } + } + + /** + * @param model Multinomial Naive Bayes model + * @param testData input to compute posterior probabilities for + * @return posterior class probabilities (in order of labels) for input + */ + private def expectedMultinomialProbabilities(model: NaiveBayesModel, testData: Vector) = { + val piVector = new BDV(model.pi) + // model.theta is row-major; treat it as col-major representation of transpose, and transpose: + val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t + val logClassProbs: BV[Double] = piVector + (thetaMatrix * testData.toBreeze) + val classProbs = logClassProbs.toArray.map(math.exp) + val classProbsSum = classProbs.sum + classProbs.map(_ / classProbsSum) } test("Naive Bayes Bernoulli") { @@ -182,6 +206,33 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + + // Test posteriors + validationData.map(_.features).foreach { features => + val predicted = model.predictProbabilities(features).toArray + assert(predicted.sum ~== 1.0 relTol 1.0e-10) + val expected = expectedBernoulliProbabilities(model, features) + expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) } + } + } + + /** + * @param model Bernoulli Naive Bayes model + * @param testData input to compute posterior probabilities for + * @return posterior class probabilities (in order of labels) for input + */ + private def expectedBernoulliProbabilities(model: NaiveBayesModel, testData: Vector) = { + val piVector = new BDV(model.pi) + val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t + val negThetaMatrix = new BDM(model.theta(0).length, model.theta.length, + model.theta.flatten.map(v => math.log(1.0 - math.exp(v)))).t + val testBreeze = testData.toBreeze + val negTestBreeze = new BDV(Array.fill(testBreeze.size)(1.0)) - testBreeze + val piTheta: BV[Double] = piVector + (thetaMatrix * testBreeze) + val logClassProbs: BV[Double] = piTheta + (negThetaMatrix * negTestBreeze) + val classProbs = logClassProbs.toArray.map(math.exp) + val classProbsSum = classProbs.sum + classProbs.map(_ / classProbsSum) } test("detect negative values") { From 11e5c372862ec00e57460b37ccfee51c6d93c5f7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 14 Jul 2015 16:08:17 -0700 Subject: [PATCH 07/46] [SPARK-8962] Add Scalastyle rule to ban direct use of Class.forName; fix existing uses This pull request adds a Scalastyle regex rule which fails the style check if `Class.forName` is used directly. `Class.forName` always loads classes from the default / system classloader, but in a majority of cases, we should be using Spark's own `Utils.classForName` instead, which tries to load classes from the current thread's context classloader and falls back to the classloader which loaded Spark when the context classloader is not defined. 
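As a quick illustration of the pattern this rule enforces, here is a minimal sketch (illustration only: the class name below is hypothetical, and `Utils.classForName` is the existing helper whose definition appears later in this diff):

    import org.apache.spark.util.Utils

    // Resolves the class against the current thread's context classloader and falls back
    // to the classloader that loaded Spark, instead of the default / system classloader
    // that a bare Class.forName("com.example.Foo") would consult.
    val clazz: Class[_] = Utils.classForName("com.example.Foo")  // hypothetical class name
    val instance = clazz.newInstance()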
[Review on Reviewable](https://reviewable.io/reviews/apache/spark/7350) Author: Josh Rosen Closes #7350 from JoshRosen/ban-Class.forName and squashes the following commits: e3e96f7 [Josh Rosen] Merge remote-tracking branch 'origin/master' into ban-Class.forName c0b7885 [Josh Rosen] Hopefully fix the last two cases d707ba7 [Josh Rosen] Fix uses of Class.forName that I missed in my first cleanup pass 046470d [Josh Rosen] Merge remote-tracking branch 'origin/master' into ban-Class.forName 62882ee [Josh Rosen] Fix uses of Class.forName or add exclusion. d9abade [Josh Rosen] Add stylechecker rule to ban uses of Class.forName --- .../main/scala/org/apache/spark/Logging.scala | 2 +- .../scala/org/apache/spark/SparkContext.scala | 11 +++++------ .../main/scala/org/apache/spark/SparkEnv.scala | 2 +- .../apache/spark/api/r/RBackendHandler.scala | 18 ++---------------- .../spark/broadcast/BroadcastManager.scala | 3 ++- .../apache/spark/deploy/SparkHadoopUtil.scala | 4 ++-- .../org/apache/spark/deploy/SparkSubmit.scala | 2 +- .../spark/deploy/SparkSubmitArguments.scala | 2 +- .../spark/deploy/history/HistoryServer.scala | 2 +- .../apache/spark/deploy/master/Master.scala | 2 +- .../rest/SubmitRestProtocolMessage.scala | 2 +- .../spark/deploy/worker/DriverWrapper.scala | 2 +- .../spark/deploy/worker/WorkerArguments.scala | 2 ++ .../org/apache/spark/executor/Executor.scala | 2 +- .../org/apache/spark/io/CompressionCodec.scala | 3 +-- .../spark/mapred/SparkHadoopMapRedUtil.scala | 5 +++-- .../mapreduce/SparkHadoopMapReduceUtil.scala | 9 +++++---- .../apache/spark/metrics/MetricsSystem.scala | 6 ++++-- .../scala/org/apache/spark/rdd/HadoopRDD.scala | 6 +++--- .../scala/org/apache/spark/rpc/RpcEnv.scala | 3 +-- .../spark/serializer/JavaSerializer.scala | 5 ++++- .../spark/serializer/KryoSerializer.scala | 2 ++ .../serializer/SerializationDebugger.scala | 2 ++ .../spark/storage/ExternalBlockStore.scala | 2 +- .../org/apache/spark/util/ClosureCleaner.scala | 2 ++ .../org/apache/spark/util/SizeEstimator.scala | 2 ++ .../scala/org/apache/spark/util/Utils.scala | 11 +++++++++-- .../scala/org/apache/spark/FileSuite.scala | 2 ++ .../SparkContextSchedulerCreationSuite.scala | 3 ++- .../apache/spark/deploy/SparkSubmitSuite.scala | 4 ++-- .../org/apache/spark/rdd/JdbcRDDSuite.scala | 3 ++- .../KryoSerializerDistributedSuite.scala | 2 ++ .../util/MutableURLClassLoaderSuite.scala | 2 ++ .../spark/streaming/flume/sink/Logging.scala | 2 ++ .../spark/graphx/util/BytecodeUtils.scala | 2 +- .../org/apache/spark/repl/SparkIMain.scala | 2 ++ scalastyle-config.xml | 11 +++++++++++ .../org/apache/spark/sql/types/DataType.scala | 3 ++- .../org/apache/spark/sql/SQLContext.scala | 3 +-- .../spark/sql/parquet/ParquetRelation.scala | 7 ++++--- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 3 ++- .../apache/spark/sql/jdbc/JDBCWriteSuite.scala | 3 ++- .../thriftserver/HiveThriftServer2Suites.scala | 2 +- .../apache/spark/sql/hive/TableReader.scala | 4 +--- .../spark/sql/hive/client/ClientWrapper.scala | 9 ++++----- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 8 ++++---- .../streaming/scheduler/JobGenerator.scala | 6 +++--- .../spark/tools/GenerateMIMAIgnore.scala | 2 ++ .../org/apache/spark/deploy/yarn/Client.scala | 4 ++-- 49 files changed, 117 insertions(+), 84 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 87ab099267b2f..f0598816d6c07 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ 
b/core/src/main/scala/org/apache/spark/Logging.scala @@ -159,7 +159,7 @@ private object Logging { try { // We use reflection here to handle the case where users remove the // slf4j-to-jul bridge order to route their logs to JUL. - val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") + val bridgeClass = Utils.classForName("org.slf4j.bridge.SLF4JBridgeHandler") bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] if (!installed) { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 82704b1ab2189..bd1cc332a63e7 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1968,7 +1968,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli for (className <- listenerClassNames) { // Use reflection to find the right constructor val constructors = { - val listenerClass = Class.forName(className) + val listenerClass = Utils.classForName(className) listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]] } val constructorTakingSparkConf = constructors.find { c => @@ -2503,7 +2503,7 @@ object SparkContext extends Logging { "\"yarn-standalone\" is deprecated as of Spark 1.0. Use \"yarn-cluster\" instead.") } val scheduler = try { - val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") + val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] } catch { @@ -2515,7 +2515,7 @@ object SparkContext extends Logging { } val backend = try { val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend") + Utils.classForName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend") val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { @@ -2528,8 +2528,7 @@ object SparkContext extends Logging { case "yarn-client" => val scheduler = try { - val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnScheduler") + val clazz = Utils.classForName("org.apache.spark.scheduler.cluster.YarnScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] @@ -2541,7 +2540,7 @@ object SparkContext extends Logging { val backend = try { val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") + Utils.classForName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index d18fc599e9890..adfece4d6e7c0 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -261,7 +261,7 @@ object SparkEnv extends Logging { // Create an instance of the class with the given name, possibly initializing it with our conf def instantiateClass[T](className: String): T = { - val cls = Class.forName(className, true, 
Utils.getContextOrSparkClassLoader) + val cls = Utils.classForName(className) // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just // SparkConf, then one taking no arguments try { diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 4b8f7fe9242e0..9658e9a696ffa 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -26,6 +26,7 @@ import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.Logging import org.apache.spark.api.r.SerDe._ +import org.apache.spark.util.Utils /** * Handler for RBackend @@ -88,21 +89,6 @@ private[r] class RBackendHandler(server: RBackend) ctx.close() } - // Looks up a class given a class name. This function first checks the - // current class loader and if a class is not found, it looks up the class - // in the context class loader. Address [SPARK-5185] - def getStaticClass(objId: String): Class[_] = { - try { - val clsCurrent = Class.forName(objId) - clsCurrent - } catch { - // Use contextLoader if we can't find the JAR in the system class loader - case e: ClassNotFoundException => - val clsContext = Class.forName(objId, true, Thread.currentThread().getContextClassLoader) - clsContext - } - } - def handleMethodCall( isStatic: Boolean, objId: String, @@ -113,7 +99,7 @@ private[r] class RBackendHandler(server: RBackend) var obj: Object = null try { val cls = if (isStatic) { - getStaticClass(objId) + Utils.classForName(objId) } else { JVMObjectTracker.get(objId) match { case None => throw new IllegalArgumentException("Object not found " + objId) diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 685313ac009ba..fac6666bb3410 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag import org.apache.spark._ +import org.apache.spark.util.Utils private[spark] class BroadcastManager( val isDriver: Boolean, @@ -42,7 +43,7 @@ private[spark] class BroadcastManager( conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") broadcastFactory = - Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] + Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] // Initialize appropriate BroadcastFactory and BroadcastObject broadcastFactory.initialize(isDriver, conf, securityManager) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 6d14590a1d192..9f94118829ff1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -178,7 +178,7 @@ class SparkHadoopUtil extends Logging { private def getFileSystemThreadStatisticsMethod(methodName: String): Method = { val statisticsDataClass = - Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") + Utils.classForName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") statisticsDataClass.getDeclaredMethod(methodName) } @@ -356,7 +356,7 @@ object SparkHadoopUtil { 
System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) if (yarnMode) { try { - Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") + Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") .newInstance() .asInstanceOf[SparkHadoopUtil] } catch { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 7089a7e26707f..036cb6e054791 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -624,7 +624,7 @@ object SparkSubmit { var mainClass: Class[_] = null try { - mainClass = Class.forName(childMainClass, true, loader) + mainClass = Utils.classForName(childMainClass) } catch { case e: ClassNotFoundException => e.printStackTrace(printStream) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index ebb39c354dff1..b3710073e330c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -576,7 +576,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S System.setSecurityManager(sm) try { - Class.forName(mainClass).getMethod("main", classOf[Array[String]]) + Utils.classForName(mainClass).getMethod("main", classOf[Array[String]]) .invoke(null, Array(HELP)) } catch { case e: InvocationTargetException => diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 10638afb74900..a076a9c3f984d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -228,7 +228,7 @@ object HistoryServer extends Logging { val providerName = conf.getOption("spark.history.provider") .getOrElse(classOf[FsHistoryProvider].getName()) - val provider = Class.forName(providerName) + val provider = Utils.classForName(providerName) .getConstructor(classOf[SparkConf]) .newInstance(conf) .asInstanceOf[ApplicationHistoryProvider] diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 48070768f6edb..245b047e7dfbd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -172,7 +172,7 @@ private[master] class Master( new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem)) (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => - val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) + val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory")) val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization]) .newInstance(conf, SerializationExtension(actorSystem)) .asInstanceOf[StandaloneRecoveryModeFactory] diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index e6615a3174ce1..ef5a7e35ad562 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -128,7 +128,7 @@ private[spark] object SubmitRestProtocolMessage { */ def fromJson(json: String): SubmitRestProtocolMessage = { val className = parseAction(json) - val clazz = Class.forName(packagePrefix + "." + className) + val clazz = Utils.classForName(packagePrefix + "." + className) .asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage]) fromJson(json, clazz) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 2d6be3042c905..6799f78ec0c19 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -53,7 +53,7 @@ object DriverWrapper { Thread.currentThread.setContextClassLoader(loader) // Delegate to supplied main class - val clazz = Class.forName(mainClass, true, loader) + val clazz = Utils.classForName(mainClass) val mainMethod = clazz.getMethod("main", classOf[Array[String]]) mainMethod.invoke(null, extraArgs.toArray[String]) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index e89d076802215..5181142c5f80e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -149,6 +149,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { val ibmVendor = System.getProperty("java.vendor").contains("IBM") var totalMb = 0 try { + // scalastyle:off classforname val bean = ManagementFactory.getOperatingSystemMXBean() if (ibmVendor) { val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean") @@ -159,6 +160,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize") totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt } + // scalastyle:on classforname } catch { case e: Exception => { totalMb = 2*1024 diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index f7ef92bc80f91..1a02051c87f19 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -356,7 +356,7 @@ private[spark] class Executor( logInfo("Using REPL class URI: " + classUri) try { val _userClassPathFirst: java.lang.Boolean = userClassPathFirst - val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") + val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] val constructor = klass.getConstructor(classOf[SparkConf], classOf[String], classOf[ClassLoader], classOf[Boolean]) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0d8ac1f80a9f4..607d5a321efca 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -63,8 +63,7 @@ private[spark] object CompressionCodec { def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName) val 
codec = try { - val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader) - .getConstructor(classOf[SparkConf]) + val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf]) Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec]) } catch { case e: ClassNotFoundException => None diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 818f7a4c8d422..87df42748be44 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.{Logging, SparkEnv, TaskContext} +import org.apache.spark.util.{Utils => SparkUtils} private[spark] trait SparkHadoopMapRedUtil { @@ -64,10 +65,10 @@ trait SparkHadoopMapRedUtil { private def firstAvailableClass(first: String, second: String): Class[_] = { try { - Class.forName(first) + SparkUtils.classForName(first) } catch { case e: ClassNotFoundException => - Class.forName(second) + SparkUtils.classForName(second) } } } diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index 390d148bc97f9..943ebcb7bd0a1 100644 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -21,6 +21,7 @@ import java.lang.{Boolean => JBoolean, Integer => JInteger} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID} +import org.apache.spark.util.Utils private[spark] trait SparkHadoopMapReduceUtil { @@ -46,7 +47,7 @@ trait SparkHadoopMapReduceUtil { isMap: Boolean, taskId: Int, attemptId: Int): TaskAttemptID = { - val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID") + val klass = Utils.classForName("org.apache.hadoop.mapreduce.TaskAttemptID") try { // First, attempt to use the old-style constructor that takes a boolean isMap // (not available in YARN) @@ -57,7 +58,7 @@ trait SparkHadoopMapReduceUtil { } catch { case exc: NoSuchMethodException => { // If that failed, look for the new constructor that takes a TaskType (not available in 1.x) - val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType") + val taskTypeClass = Utils.classForName("org.apache.hadoop.mapreduce.TaskType") .asInstanceOf[Class[Enum[_]]] val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke( taskTypeClass, if (isMap) "MAP" else "REDUCE") @@ -71,10 +72,10 @@ trait SparkHadoopMapReduceUtil { private def firstAvailableClass(first: String, second: String): Class[_] = { try { - Class.forName(first) + Utils.classForName(first) } catch { case e: ClassNotFoundException => - Class.forName(second) + Utils.classForName(second) } } } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index ed5131c79fdc5..67f64d5e278de 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -20,6 +20,8 @@ package org.apache.spark.metrics import java.util.Properties import 
java.util.concurrent.TimeUnit +import org.apache.spark.util.Utils + import scala.collection.mutable import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} @@ -166,7 +168,7 @@ private[spark] class MetricsSystem private ( sourceConfigs.foreach { kv => val classPath = kv._2.getProperty("class") try { - val source = Class.forName(classPath).newInstance() + val source = Utils.classForName(classPath).newInstance() registerSource(source.asInstanceOf[Source]) } catch { case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e) @@ -182,7 +184,7 @@ private[spark] class MetricsSystem private ( val classPath = kv._2.getProperty("class") if (null != classPath) { try { - val sink = Class.forName(classPath) + val sink = Utils.classForName(classPath) .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]) .newInstance(kv._2, registry, securityMgr) if (kv._1 == "servlet") { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index bee59a437f120..f1c17369cb48c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -383,11 +383,11 @@ private[spark] object HadoopRDD extends Logging { private[spark] class SplitInfoReflections { val inputSplitWithLocationInfo = - Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo") + Utils.classForName("org.apache.hadoop.mapred.InputSplitWithLocationInfo") val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo") - val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit") + val newInputSplit = Utils.classForName("org.apache.hadoop.mapreduce.InputSplit") val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo") - val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo") + val splitLocationInfo = Utils.classForName("org.apache.hadoop.mapred.SplitLocationInfo") val isInMemory = splitLocationInfo.getMethod("isInMemory") val getLocation = splitLocationInfo.getMethod("getLocation") } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 1709bdf560b6f..c9fcc7a36cc04 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -39,8 +39,7 @@ private[spark] object RpcEnv { val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") val rpcEnvName = conf.get("spark.rpc", "akka") val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) - Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). 
- newInstance().asInstanceOf[RpcEnvFactory] + Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory] } def create( diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 698d1384d580d..4a5274b46b7a0 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -62,8 +62,11 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa extends DeserializationStream { private val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass): Class[_] = + override def resolveClass(desc: ObjectStreamClass): Class[_] = { + // scalastyle:off classforname Class.forName(desc.getName, false, loader) + // scalastyle:on classforname + } } def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T] diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index ed35cffe968f8..7cb6e080533ad 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -102,6 +102,7 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) try { + // scalastyle:off classforname // Use the default classloader when calling the user registrator. Thread.currentThread.setContextClassLoader(classLoader) // Register classes given through spark.kryo.classesToRegister. @@ -111,6 +112,7 @@ class KryoSerializer(conf: SparkConf) userRegistrator .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) .foreach { reg => reg.registerClasses(kryo) } + // scalastyle:on classforname } catch { case e: Exception => throw new SparkException(s"Failed to register classes with Kryo", e) diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index cc2f0506817d3..a1b1e1631eafb 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -407,7 +407,9 @@ private[spark] object SerializationDebugger extends Logging { /** ObjectStreamClass$ClassDataSlot.desc field */ val DescField: Field = { + // scalastyle:off classforname val f = Class.forName("java.io.ObjectStreamClass$ClassDataSlot").getDeclaredField("desc") + // scalastyle:on classforname f.setAccessible(true) f } diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala index 291394ed34816..db965d54bafd6 100644 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala @@ -192,7 +192,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: .getOrElse(ExternalBlockStore.DEFAULT_BLOCK_MANAGER_NAME) try { - val instance = Class.forName(clsName) + val instance = Utils.classForName(clsName) .newInstance() .asInstanceOf[ExternalBlockManager] instance.init(blockManager, executorId) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala 
b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 305de4c75539d..43626b4ef4880 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -448,10 +448,12 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM if (op == INVOKESPECIAL && name == "" && argTypes.length > 0 && argTypes(0).toString.startsWith("L") // is it an object? && argTypes(0).getInternalName == myName) { + // scalastyle:off classforname output += Class.forName( owner.replace('/', '.'), false, Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname } } } diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 0180399c9dad5..7d84468f62ab1 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -124,9 +124,11 @@ object SizeEstimator extends Logging { val server = ManagementFactory.getPlatformMBeanServer() // NOTE: This should throw an exception in non-Sun JVMs + // scalastyle:off classforname val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean") val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption", Class.forName("java.lang.String")) + // scalastyle:on classforname val bean = ManagementFactory.newPlatformMXBeanProxy(server, hotSpotMBeanName, hotSpotMBeanClass) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b6b932104a94d..e6374f17d858f 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -113,8 +113,11 @@ private[spark] object Utils extends Logging { def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { val bis = new ByteArrayInputStream(bytes) val ois = new ObjectInputStream(bis) { - override def resolveClass(desc: ObjectStreamClass): Class[_] = + override def resolveClass(desc: ObjectStreamClass): Class[_] = { + // scalastyle:off classforname Class.forName(desc.getName, false, loader) + // scalastyle:on classforname + } } ois.readObject.asInstanceOf[T] } @@ -177,12 +180,16 @@ private[spark] object Utils extends Logging { /** Determines whether the provided class is loadable in the current thread. 
*/ def classIsLoadable(clazz: String): Boolean = { + // scalastyle:off classforname Try { Class.forName(clazz, false, getContextOrSparkClassLoader) }.isSuccess + // scalastyle:on classforname } + // scalastyle:off classforname /** Preferred alternative to Class.forName(className) */ def classForName(className: String): Class[_] = { Class.forName(className, true, getContextOrSparkClassLoader) + // scalastyle:on classforname } /** @@ -2266,7 +2273,7 @@ private [util] class SparkShutdownHookManager { val hookTask = new Runnable() { override def run(): Unit = runAll() } - Try(Class.forName("org.apache.hadoop.util.ShutdownHookManager")) match { + Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { case Success(shmClass) => val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() .asInstanceOf[Int] diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 1d8fade90f398..418763f4e5ffa 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -179,6 +179,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } test("object files of classes from a JAR") { + // scalastyle:off classforname val original = Thread.currentThread().getContextClassLoader val className = "FileSuiteObjectFileTest" val jar = TestUtils.createJarWithClasses(Seq(className)) @@ -201,6 +202,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { finally { Thread.currentThread().setContextClassLoader(original) } + // scalastyle:on classforname } test("write SequenceFile using new Hadoop API") { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index f89e3d0a49920..dba46f101c580 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark import org.scalatest.PrivateMethodTester +import org.apache.spark.util.Utils import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} @@ -131,7 +132,7 @@ class SparkContextSchedulerCreationSuite def testYarn(master: String, expectedClassName: String) { try { val sched = createTaskScheduler(master) - assert(sched.getClass === Class.forName(expectedClassName)) + assert(sched.getClass === Utils.classForName(expectedClassName)) } catch { case e: SparkException => assert(e.getMessage.contains("YARN mode not available")) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index e7878bde6fcb0..343d28eef8359 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -541,8 +541,8 @@ object JarCreationTest extends Logging { val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => var exception: String = null try { - Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) - Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + Utils.classForName(args(0)) + 
Utils.classForName(args(1)) } catch { case t: Throwable => exception = t + "\n" + t.getStackTraceString diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala index 08215a2bafc09..05013fbc49b8e 100644 --- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala @@ -22,11 +22,12 @@ import java.sql._ import org.scalatest.BeforeAndAfter import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.util.Utils class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { before { - Class.forName("org.apache.derby.jdbc.EmbeddedDriver") + Utils.classForName("org.apache.derby.jdbc.EmbeddedDriver") val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true") try { diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index 63a8480c9b57b..353b97469cd11 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -59,7 +59,9 @@ object KryoDistributedTest { class AppJarRegistrator extends KryoRegistrator { override def registerClasses(k: Kryo) { val classLoader = Thread.currentThread.getContextClassLoader + // scalastyle:off classforname k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader)) + // scalastyle:on classforname } } diff --git a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index 42125547436cb..d3d464e84ffd7 100644 --- a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -84,7 +84,9 @@ class MutableURLClassLoaderSuite extends SparkFunSuite { try { sc.makeRDD(1 to 5, 2).mapPartitions { x => val loader = Thread.currentThread().getContextClassLoader + // scalastyle:off classforname Class.forName(className, true, loader).newInstance() + // scalastyle:on classforname Seq().iterator }.count() } diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala index 17cbc6707b5ea..d87b86932dd41 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala @@ -113,7 +113,9 @@ private[sink] object Logging { try { // We use reflection here to handle the case where users remove the // slf4j-to-jul bridge order to route their logs to JUL. 
+ // scalastyle:off classforname val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") + // scalastyle:on classforname bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] if (!installed) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index 5c07b415cd796..74a7de18d4161 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -121,7 +121,7 @@ private[graphx] object BytecodeUtils { override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { if (!skipClass(owner)) { - methodsInvoked.add((Class.forName(owner.replace("/", ".")), name)) + methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name)) } } } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 8791618bd355e..4ee605fd7f11e 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -1079,8 +1079,10 @@ import org.apache.spark.annotation.DeveloperApi throw new EvalException("Failed to load '" + path + "': " + ex.getMessage, ex) private def load(path: String): Class[_] = { + // scalastyle:off classforname try Class.forName(path, true, classLoader) catch { case ex: Throwable => evalError(path, unwrap(ex)) } + // scalastyle:on classforname } lazy val evalClass = load(evalPath) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 49611703798e8..b5e2e882d2254 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -150,6 +150,17 @@ This file is divided into 3 sections: // scalastyle:on println]]> + + Class\.forName + + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 57718228e490f..da83a7f0ba379 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -27,6 +27,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.util.Utils /** @@ -146,7 +147,7 @@ object DataType { ("pyClass", _), ("sqlType", _), ("type", JString("udt"))) => - Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] + Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] } private def parseStructField(json: JValue): StructField = json match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 477dea9164726..46bd60daa1f78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -554,8 +554,7 @@ class SQLContext(@transient val sparkContext: SparkContext) val className = beanClass.getName val rowRdd = rdd.mapPartitions { iter => // BeanInfo is not serializable so we must rediscover it remotely for each partition. 
- val localBeanInfo = Introspector.getBeanInfo( - Class.forName(className, true, Utils.getContextOrSparkClassLoader)) + val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className)) val extractors = localBeanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod) val methodsToConverts = extractors.zip(attributeSeq).map { case (e, attr) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 704cf56f38265..e0bea65a15f36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.util.Utils /** * Relation that consists of data stored in a Parquet columnar format. @@ -108,7 +109,7 @@ private[sql] object ParquetRelation { // // Therefore we need to force the class to be loaded. // This should really be resolved by Parquet. - Class.forName(classOf[ParquetLog].getName) + Utils.classForName(classOf[ParquetLog].getName) // Note: Logger.getLogger("parquet") has a default logger // that appends to Console which needs to be cleared. @@ -119,12 +120,12 @@ private[sql] object ParquetRelation { // Disables a WARN log message in ParquetOutputCommitter. We first ensure that // ParquetOutputCommitter is loaded and the static LOG field gets initialized. // See https://issues.apache.org/jira/browse/SPARK-5968 for details - Class.forName(classOf[ParquetOutputCommitter].getName) + Utils.classForName(classOf[ParquetOutputCommitter].getName) JLogger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF) // Similar as above, disables a unnecessary WARN log message in ParquetRecordReader. // See https://issues.apache.org/jira/browse/PARQUET-220 for details - Class.forName(classOf[ParquetRecordReader[_]].getName) + Utils.classForName(classOf[ParquetRecordReader[_]].getName) JLogger.getLogger(classOf[ParquetRecordReader[_]].getName).setLevel(Level.OFF) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 566a52dc1b784..0f82f13088d39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class JDBCSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb0" @@ -46,7 +47,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { import ctx.sql before { - Class.forName("org.h2.Driver") + Utils.classForName("org.h2.Driver") // Extra properties that will be specified for our database. We need these to test // usage of parameters from OPTIONS clause in queries. 
val properties = new Properties() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index d949ef42267ec..84b52ca2c733c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{SaveMode, Row} import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb2" @@ -41,7 +42,7 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { import ctx.sql before { - Class.forName("org.h2.Driver") + Utils.classForName("org.h2.Driver") conn = DriverManager.getConnection(url) conn.prepareStatement("create schema test").executeUpdate() diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 301aa5a6411e2..39b31523e07cb 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -417,7 +417,7 @@ object ServerMode extends Enumeration { } abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { - Class.forName(classOf[HiveDriver].getCanonicalName) + Utils.classForName(classOf[HiveDriver].getCanonicalName) private def jdbcUri = if (mode == ServerMode.http) { s"""jdbc:hive2://localhost:$serverPort/ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index d65d29daacf31..dc355690852bd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -78,9 +78,7 @@ class HadoopTableReader( override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = makeRDDForTable( hiveTable, - Class.forName( - relation.tableDesc.getSerdeClassName, true, Utils.getContextOrSparkClassLoader) - .asInstanceOf[Class[Deserializer]], + Utils.classForName(relation.tableDesc.getSerdeClassName).asInstanceOf[Class[Deserializer]], filterOpt = None) /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 1f280c642979a..8adda54754230 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -21,9 +21,6 @@ import java.io.{File, PrintStream} import java.util.{Map => JMap} import javax.annotation.concurrent.GuardedBy -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.util.CircularBuffer - import scala.collection.JavaConversions._ import scala.language.reflectiveCalls @@ -37,7 +34,9 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.ql.{Driver, metadata} import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.util.{CircularBuffer, 
Utils} /** @@ -249,10 +248,10 @@ private[hive] class ClientWrapper( } private def toInputFormat(name: String) = - Class.forName(name).asInstanceOf[Class[_ <: org.apache.hadoop.mapred.InputFormat[_, _]]] + Utils.classForName(name).asInstanceOf[Class[_ <: org.apache.hadoop.mapred.InputFormat[_, _]]] private def toOutputFormat(name: String) = - Class.forName(name) + Utils.classForName(name) .asInstanceOf[Class[_ <: org.apache.hadoop.hive.ql.io.HiveOutputFormat[_, _]]] private def toQlTable(table: HiveTable): metadata.Table = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 917900e5f46dc..bee2ecbedb244 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -120,8 +120,8 @@ object SparkSubmitClassLoaderTest extends Logging { logInfo("Testing load classes at the driver side.") // First, we load classes at driver side. try { - Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) - Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + Utils.classForName(args(0)) + Utils.classForName(args(1)) } catch { case t: Throwable => throw new Exception("Could not load user class from jar:\n", t) @@ -131,8 +131,8 @@ object SparkSubmitClassLoaderTest extends Logging { val result = df.mapPartitions { x => var exception: String = null try { - Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) - Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + Utils.classForName(args(0)) + Utils.classForName(args(1)) } catch { case t: Throwable => exception = t + "\n" + t.getStackTraceString diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index f5d41858646e4..9f2117ada61c0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -22,7 +22,7 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.{Clock, EventLoop, ManualClock} +import org.apache.spark.util.{Utils, Clock, EventLoop, ManualClock} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent @@ -47,11 +47,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val clockClass = ssc.sc.conf.get( "spark.streaming.clock", "org.apache.spark.util.SystemClock") try { - Class.forName(clockClass).newInstance().asInstanceOf[Clock] + Utils.classForName(clockClass).newInstance().asInstanceOf[Clock] } catch { case e: ClassNotFoundException if clockClass.startsWith("org.apache.spark.streaming") => val newClockClass = clockClass.replace("org.apache.spark.streaming", "org.apache.spark") - Class.forName(newClockClass).newInstance().asInstanceOf[Clock] + Utils.classForName(newClockClass).newInstance().asInstanceOf[Clock] } } diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 9483d2b692ab5..9418beb6b3e3a 100644 --- 
a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -15,6 +15,7 @@ * limitations under the License. */ +// scalastyle:off classforname package org.apache.spark.tools import java.io.File @@ -188,3 +189,4 @@ object GenerateMIMAIgnore { classes } } +// scalastyle:on classforname diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index f0af6f875f523..f86b6d1e5d7bc 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -732,9 +732,9 @@ private[spark] class Client( } val amClass = if (isClusterMode) { - Class.forName("org.apache.spark.deploy.yarn.ApplicationMaster").getName + Utils.classForName("org.apache.spark.deploy.yarn.ApplicationMaster").getName } else { - Class.forName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName + Utils.classForName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName } if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs From e965a798d09a9fba61b104c5cc0b65cdc28d27f6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 14 Jul 2015 17:21:48 -0700 Subject: [PATCH 08/46] [SPARK-9045] Fix Scala 2.11 build break in UnsafeExternalRowSorter This fixes a compilation break in under Scala 2.11: ``` [error] /home/jenkins/workspace/Spark-Master-Scala211-Compile/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java:135: error: is not abstract and does not override abstract method minBy(Function1,Ordering) in TraversableOnce [error] return new AbstractScalaRowIterator() { [error] ^ [error] where B,A are type-variables: [error] B extends Object declared in method minBy(Function1,Ordering) [error] A extends Object declared in interface TraversableOnce [error] 1 error ``` The workaround for this is to make `AbstractScalaRowIterator` into a concrete class. Author: Josh Rosen Closes #7405 from JoshRosen/SPARK-9045 and squashes the following commits: cbcbb4c [Josh Rosen] Forgot that we can't use the ??? operator anymore 577ba60 [Josh Rosen] [SPARK-9045] Fix Scala 2.11 build break in UnsafeExternalRowSorter. --- .../apache/spark/sql/AbstractScalaRowIterator.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala index cfefb13e7721e..1090bdb5a4bd3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.InternalRow - /** * Shim to allow us to implement [[scala.Iterator]] in Java. Scala 2.11+ has an AbstractIterator * class for this, but that class is `private[scala]` in 2.10. We need to explicitly fix this to - * `Row` in order to work around a spurious IntelliJ compiler error. + * `Row` in order to work around a spurious IntelliJ compiler error. This cannot be an abstract + * class because that leads to compilation errors under Scala 2.11. 
*/ -private[spark] abstract class AbstractScalaRowIterator extends Iterator[InternalRow] +private[spark] class AbstractScalaRowIterator[T] extends Iterator[T] { + override def hasNext: Boolean = throw new NotImplementedError + + override def next(): T = throw new NotImplementedError +} From cc57d705e732aefc2f3d3f438e84d71705b2eb65 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 14 Jul 2015 18:55:34 -0700 Subject: [PATCH 09/46] [SPARK-9050] [SQL] Remove unused newOrdering argument from Exchange (cleanup after SPARK-8317) SPARK-8317 changed the SQL Exchange operator so that it no longer pushed sorting into Spark's shuffle layer, a change which allowed more efficient SQL-specific sorters to be used. This patch performs some leftover cleanup based on those changes: - Exchange's constructor should no longer accept a `newOrdering` since it's no longer used and no longer works as expected. - `addOperatorsIfNecessary` looked at shuffle input's output ordering to decide whether to sort, but this is the wrong node to be examining: it needs to look at whether the post-shuffle node has the right ordering, since shuffling will not preserve row orderings. Thanks to davies for spotting this. Author: Josh Rosen Closes #7407 from JoshRosen/SPARK-9050 and squashes the following commits: e70be50 [Josh Rosen] No need to wrap line e866494 [Josh Rosen] Refactor addOperatorsIfNecessary to make code clearer 2e467da [Josh Rosen] Remove `newOrdering` from Exchange. --- .../apache/spark/sql/execution/Exchange.scala | 37 ++++++++----------- .../spark/sql/execution/SparkStrategies.scala | 3 +- 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 4b783e30d95e1..feea4f239c04d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -35,21 +35,13 @@ import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEn /** * :: DeveloperApi :: - * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each - * resulting partition based on expressions from the partition key. It is invalid to construct an - * exchange operator with a `newOrdering` that cannot be calculated using the partitioning key. + * Performs a shuffle that will result in the desired `newPartitioning`. 
*/ @DeveloperApi -case class Exchange( - newPartitioning: Partitioning, - newOrdering: Seq[SortOrder], - child: SparkPlan) - extends UnaryNode { +case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { override def outputPartitioning: Partitioning = newPartitioning - override def outputOrdering: Seq[SortOrder] = newOrdering - override def output: Seq[Attribute] = child.output /** @@ -279,23 +271,24 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ partitioning: Partitioning, rowOrdering: Seq[SortOrder], child: SparkPlan): SparkPlan = { - val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering - val needsShuffle = child.outputPartitioning != partitioning - val withShuffle = if (needsShuffle) { - Exchange(partitioning, Nil, child) - } else { - child + def addShuffleIfNecessary(child: SparkPlan): SparkPlan = { + if (child.outputPartitioning != partitioning) { + Exchange(partitioning, child) + } else { + child + } } - val withSort = if (needSort) { - sqlContext.planner.BasicOperators.getSortOperator( - rowOrdering, global = false, withShuffle) - } else { - withShuffle + def addSortIfNecessary(child: SparkPlan): SparkPlan = { + if (rowOrdering.nonEmpty && child.outputOrdering != rowOrdering) { + sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) + } else { + child + } } - withSort + addSortIfNecessary(addShuffleIfNecessary(child)) } if (meetsRequirements && compatible && !needsAnySort) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ce25af58b6cab..73b463471ec5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -360,8 +360,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.RepartitionByExpression(expressions, child) => - execution.Exchange( - HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil + execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil From f957796c4b3c3cd95edfc64500a045f7e810ee87 Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Tue, 14 Jul 2015 19:20:49 -0700 Subject: [PATCH 10/46] [SPARK-8820] [STREAMING] Add a configuration to set checkpoint dir. Add a configuration to set checkpoint directory for convenience to user. [Jira Address](https://issues.apache.org/jira/browse/SPARK-8820) Author: huangzhaowei Closes #7218 from SaintBacchus/SPARK-8820 and squashes the following commits: d49fe4b [huangzhaowei] Rename the configuration name 66ea47c [huangzhaowei] Add the unit test. dd0acc1 [huangzhaowei] [SPARK-8820][Streaming] Add a configuration to set checkpoint dir. 
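A minimal usage sketch of the new option, assuming the patch is applied (the master URL and checkpoint path below are placeholder values, not taken from the patch):

```
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

// Placeholder master and checkpoint path, for illustration only.
val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("CheckpointFromConf")
  .set("spark.streaming.checkpoint.directory", "/tmp/spark-checkpoints")

// With this change the StreamingContext constructor reads the option above and
// calls checkpoint(...) itself, so no explicit ssc.checkpoint(dir) call is needed.
val ssc = new StreamingContext(conf, Seconds(1))
```

The new "checkPoint from conf" test added to StreamingContextSuite below exercises exactly this path.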
--- .../org/apache/spark/streaming/StreamingContext.scala | 2 ++ .../apache/spark/streaming/StreamingContextSuite.scala | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 6b78a82e68c24..92438f1b1fbf7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -201,6 +201,8 @@ class StreamingContext private[streaming] ( private var shutdownHookRef: AnyRef = _ + conf.getOption("spark.streaming.checkpoint.directory").foreach(checkpoint) + /** * Return the associated Spark context */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 289a159d8990a..f588cf5bc1e7c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -115,6 +115,15 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) } + test("checkPoint from conf") { + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + myConf.set("spark.streaming.checkpoint.directory", checkpointDirectory) + val ssc = new StreamingContext(myConf, batchDuration) + assert(ssc.checkpointDir != null) + } + test("state matching") { import StreamingContextState._ assert(INITIALIZED === INITIALIZED) From bb870e72f42b6ce8d056df259f6fcf41808d7ed2 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 14 Jul 2015 19:54:02 -0700 Subject: [PATCH 11/46] [SPARK-5523] [CORE] [STREAMING] Add a cache for hostname in TaskMetrics to decrease the memory usage and GC overhead Hostname in TaskMetrics will be created through deserialization, mostly the number of hostname is only the order of number of cluster node, so adding a cache layer to dedup the object could reduce the memory usage and alleviate GC overhead, especially for long-running and fast job generation applications like Spark Streaming. 
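The caching described above amounts to interning hostnames through a concurrent map. A minimal standalone sketch of the idea follows; the object and method names here are hypothetical, while in the patch itself the logic lives in TaskMetrics.getCachedHostName and is invoked from a custom readObject:

```
import java.util.concurrent.ConcurrentHashMap

// Hypothetical helper: keep one canonical String per distinct hostname so that
// duplicate copies created during deserialization can be discarded immediately.
object HostNameCache {
  private val cache = new ConcurrentHashMap[String, String]()

  def canonicalize(host: String): String = {
    // putIfAbsent returns the existing mapping, or null if this call inserted `host`.
    val existing = cache.putIfAbsent(host, host)
    if (existing != null) existing else host
  }
}
```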
Author: jerryshao Author: Saisai Shao Closes #5064 from jerryshao/SPARK-5523 and squashes the following commits: 3e2412a [jerryshao] Address the comments b092a81 [Saisai Shao] Add a pool to cache the hostname --- .../apache/spark/executor/TaskMetrics.scala | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index a3b4561b07e7f..e80feeeab4142 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,11 +17,15 @@ package org.apache.spark.executor +import java.io.{IOException, ObjectInputStream} +import java.util.concurrent.ConcurrentHashMap + import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.DataReadMethod.DataReadMethod import org.apache.spark.storage.{BlockId, BlockStatus} +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -210,10 +214,26 @@ class TaskMetrics extends Serializable { private[spark] def updateInputMetrics(): Unit = synchronized { inputMetrics.foreach(_.updateBytesRead()) } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + in.defaultReadObject() + // Get the hostname from cached data, since hostname is the order of number of nodes in + // cluster, so using cached hostname will decrease the object number and alleviate the GC + // overhead. + _hostname = TaskMetrics.getCachedHostName(_hostname) + } } private[spark] object TaskMetrics { + private val hostNameCache = new ConcurrentHashMap[String, String]() + def empty: TaskMetrics = new TaskMetrics + + def getCachedHostName(host: String): String = { + val canonicalHost = hostNameCache.putIfAbsent(host, host) + if (canonicalHost != null) canonicalHost else host + } } /** From 5572fd0c518acd2e4483ff41bea1eb1cffd543ce Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 14 Jul 2015 21:44:47 -0700 Subject: [PATCH 12/46] [HOTFIX] Adding new names to known contributors --- dev/create-release/known_translations | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations index 5f2671a6e5053..e462302f28423 100644 --- a/dev/create-release/known_translations +++ b/dev/create-release/known_translations @@ -129,3 +129,12 @@ yongtang - Yong Tang ypcat - Pei-Lun Lee zhichao-li - Zhichao Li zzcclp - Zhichao Zhang +979969786 - Yuming Wang +Rosstin - Rosstin Murphy +ameyc - Amey Chaugule +animeshbaranawal - Animesh Baranawal +cafreeman - Chris Freeman +lee19 - Lee +lockwobr - Brian Lockwood +navis - Navis Ryu +pparkkin - Paavo Parkkinen From f650a005e03ecd800c9005a496cc6a0d8eb68c93 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Tue, 14 Jul 2015 22:21:01 -0700 Subject: [PATCH 13/46] [SPARK-8808] [SPARKR] Fix assignments in SparkR. Author: Sun Rui Closes #7395 from sun-rui/SPARK-8808 and squashes the following commits: ce603bc [Sun Rui] Use '<-' instead of '='. 88590b1 [Sun Rui] Use '<-' instead of '='. 
--- R/pkg/R/DataFrame.R | 2 +- R/pkg/R/client.R | 4 ++-- R/pkg/R/group.R | 4 ++-- R/pkg/R/utils.R | 4 ++-- R/pkg/inst/tests/test_binaryFile.R | 2 +- R/pkg/inst/tests/test_binary_function.R | 2 +- R/pkg/inst/tests/test_rdd.R | 4 ++-- R/pkg/inst/tests/test_textFile.R | 2 +- R/pkg/inst/tests/test_utils.R | 2 +- 9 files changed, 13 insertions(+), 13 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 60702824acb46..208813768e264 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1328,7 +1328,7 @@ setMethod("write.df", jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) options <- varargsToEnv(...) if (!is.null(path)) { - options[['path']] = path + options[['path']] <- path } callJMethod(df@sdf, "save", source, jmode, options) }) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 78c7a3037ffac..6f772158ddfe8 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -36,9 +36,9 @@ connectBackend <- function(hostname, port, timeout = 6000) { determineSparkSubmitBin <- function() { if (.Platform$OS.type == "unix") { - sparkSubmitBinName = "spark-submit" + sparkSubmitBinName <- "spark-submit" } else { - sparkSubmitBinName = "spark-submit.cmd" + sparkSubmitBinName <- "spark-submit.cmd" } sparkSubmitBinName } diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 8f1c68f7c4d28..576ac72f40fc0 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -87,7 +87,7 @@ setMethod("count", setMethod("agg", signature(x = "GroupedData"), function(x, ...) { - cols = list(...) + cols <- list(...) stopifnot(length(cols) > 0) if (is.character(cols[[1]])) { cols <- varargsToEnv(...) @@ -97,7 +97,7 @@ setMethod("agg", if (!is.null(ns)) { for (n in ns) { if (n != "") { - cols[[n]] = alias(cols[[n]], n) + cols[[n]] <- alias(cols[[n]], n) } } } diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index ea629a64f7158..950ba74dbe017 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -41,8 +41,8 @@ convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, if (isInstanceOf(obj, "scala.Tuple2")) { # JavaPairRDD[Array[Byte], Array[Byte]]. 
- keyBytes = callJMethod(obj, "_1") - valBytes = callJMethod(obj, "_2") + keyBytes <- callJMethod(obj, "_1") + valBytes <- callJMethod(obj, "_2") res <- list(unserialize(keyBytes), unserialize(valBytes)) } else { diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R index ccaea18ecab2a..f2452ed97d2ea 100644 --- a/R/pkg/inst/tests/test_binaryFile.R +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -20,7 +20,7 @@ context("functions on binary files") # JavaSparkContext handle sc <- sparkR.init() -mockFile = c("Spark is pretty.", "Spark is awesome.") +mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("saveAsObjectFile()/objectFile() following textFile() works", { fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index 3be8c65a6c1a0..dca0657c57e0d 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -76,7 +76,7 @@ test_that("zipPartitions() on RDDs", { expect_equal(actual, list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index b79692873cec3..6c3aaab8c711e 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -447,7 +447,7 @@ test_that("zipRDD() on RDDs", { expect_equal(actual, list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) @@ -483,7 +483,7 @@ test_that("cartesian() on RDDs", { actual <- collect(cartesian(rdd, emptyRdd)) expect_equal(actual, list()) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R index 58318dfef71ab..a9cf83dbdbdb1 100644 --- a/R/pkg/inst/tests/test_textFile.R +++ b/R/pkg/inst/tests/test_textFile.R @@ -20,7 +20,7 @@ context("the textFile() function") # JavaSparkContext handle sc <- sparkR.init() -mockFile = c("Spark is pretty.", "Spark is awesome.") +mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { fileName <- tempfile(pattern="spark-test", fileext=".tmp") diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R index aa0d2a66b9082..12df4cf4f65b7 100644 --- a/R/pkg/inst/tests/test_utils.R +++ b/R/pkg/inst/tests/test_utils.R @@ -119,7 +119,7 @@ test_that("cleanClosure on R functions", { # Test for overriding variables in base namespace (Issue: SparkR-196). nums <- as.list(1:10) rdd <- parallelize(sc, nums, 2L) - t = 4 # Override base::t in .GlobalEnv. + t <- 4 # Override base::t in .GlobalEnv. f <- function(x) { x > t } newF <- cleanClosure(f) env <- environment(newF) From f23a721c10b64ec5c6768634fc5e9e7b60ee7ca8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 14 Jul 2015 22:52:53 -0700 Subject: [PATCH 14/46] [SPARK-8993][SQL] More comprehensive type checking in expressions. This patch makes the following changes: 1. 
ExpectsInputTypes only defines expected input types, but does not perform any implicit type casting. 2. ImplicitCastInputTypes is a new trait that defines both expected input types, as well as performs implicit type casting. 3. BinaryOperator has a new abstract function "inputType", which defines the expected input type for both left/right. Concrete BinaryOperator expressions no longer perform any implicit type casting. 4. For BinaryOperators, convert NullType (i.e. null literals) into some accepted type so BinaryOperators don't need to handle NullTypes. TODOs needed: fix unit tests for error reporting. I'm intentionally not changing anything in aggregate expressions because yhuai is doing a big refactoring on that right now. Author: Reynold Xin Closes #7348 from rxin/typecheck and squashes the following commits: 8fcf814 [Reynold Xin] Fixed ordering of cases. 3bb63e7 [Reynold Xin] Style fix. f45408f [Reynold Xin] Comment update. aa7790e [Reynold Xin] Moved RemoveNullTypes into ImplicitTypeCasts. 438ea07 [Reynold Xin] space d55c9e5 [Reynold Xin] Removes NullTypes. 360d124 [Reynold Xin] Fixed the rule. fb66657 [Reynold Xin] Convert NullType into some accepted type for BinaryOperators. 2e22330 [Reynold Xin] Fixed unit tests. 4932d57 [Reynold Xin] Style fix. d061691 [Reynold Xin] Rename existing ExpectsInputTypes -> ImplicitCastInputTypes. e4727cc [Reynold Xin] BinaryOperator should not be doing implicit cast. d017861 [Reynold Xin] Improve expression type checking. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../catalyst/analysis/HiveTypeCoercion.scala | 43 ++++++---- .../expressions/ExpectsInputTypes.scala | 17 +++- .../sql/catalyst/expressions/Expression.scala | 44 +++++++++- .../sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../sql/catalyst/expressions/arithmetic.scala | 84 ++++++++----------- .../sql/catalyst/expressions/bitwise.scala | 30 +++---- .../spark/sql/catalyst/expressions/math.scala | 18 ++-- .../spark/sql/catalyst/expressions/misc.scala | 8 +- .../sql/catalyst/expressions/predicates.scala | 83 ++++++++++-------- .../expressions/stringOperations.scala | 36 ++++---- .../spark/sql/catalyst/util/TypeUtils.scala | 8 -- .../spark/sql/types/AbstractDataType.scala | 35 ++++++++ .../analysis/AnalysisErrorSuite.scala | 2 +- .../ExpressionTypeCheckingSuite.scala | 6 +- .../analysis/HiveTypeCoercionSuite.scala | 56 +++++++++++++ .../spark/sql/MathExpressionsSuite.scala | 1 - 17 files changed, 309 insertions(+), 165 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ed69c42dcb825..6b1a94e4b2ad4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import scala.language.existentials import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 8cb71995eb818..15da5eecc8d3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -214,19 +214,6 @@ object 
HiveTypeCoercion { } Union(newLeft, newRight) - - // Also widen types for BinaryOperator. - case q: LogicalPlan => q transformExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e - - case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) - b.makeCopy(Array(newLeft, newRight)) - }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. - } } } @@ -672,20 +659,44 @@ object HiveTypeCoercion { } /** - * Casts types according to the expected input types for Expressions that have the trait - * [[ExpectsInputTypes]]. + * Casts types according to the expected input types for [[Expression]]s. */ object ImplicitTypeCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case e: ExpectsInputTypes if (e.inputTypes.nonEmpty) => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => + findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType => + if (b.inputType.acceptsType(commonType)) { + // If the expression accepts the tighest common type, cast to that. + val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) + val newRight = if (right.dataType == commonType) right else Cast(right, commonType) + b.makeCopy(Array(newLeft, newRight)) + } else { + // Otherwise, don't do anything with the expression. + b + } + }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. + + case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => // If we cannot do the implicit cast, just use the original input. implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) + + case e: ExpectsInputTypes if e.inputTypes.nonEmpty => + // Convert NullType into some specific target type for ExpectsInputTypes that don't do + // general implicit casting. + val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => + if (in.dataType == NullType && !expected.acceptsType(NullType)) { + Cast(in, expected.defaultConcreteType) + } else { + in + } + } + e.withNewChildren(children) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 3eb0eb195c80d..ded89e85dea79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -19,10 +19,15 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.types.AbstractDataType - +import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.ImplicitTypeCasts /** * An trait that gets mixin to define the expected input types of an expression. + * + * This trait is typically used by operator expressions (e.g. [[Add]], [[Subtract]]) to define + * expected input types without any implicit casting. 
+ * + * Most function expressions (e.g. [[Substring]] should extends [[ImplicitCastInputTypes]]) instead. */ trait ExpectsInputTypes { self: Expression => @@ -40,7 +45,7 @@ trait ExpectsInputTypes { self: Expression => val mismatches = children.zip(inputTypes).zipWithIndex.collect { case ((child, expected), idx) if !expected.acceptsType(child.dataType) => s"argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + - s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}." + s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}." } if (mismatches.isEmpty) { @@ -50,3 +55,11 @@ trait ExpectsInputTypes { self: Expression => } } } + + +/** + * A mixin for the analyzer to perform implicit type casting using [[ImplicitTypeCasts]]. + */ +trait ImplicitCastInputTypes extends ExpectsInputTypes { self: Expression => + // No other methods +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 54ec10444c4f3..3f19ac2b592b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -24,8 +24,20 @@ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines the basic expression abstract classes in Catalyst, including: +// Expression: the base expression abstract class +// LeafExpression +// UnaryExpression +// BinaryExpression +// BinaryOperator +// +// For details, see their classdocs. +//////////////////////////////////////////////////////////////////////////////////////////////////// /** + * An expression in Catalyst. + * * If an expression wants to be exposed in the function registry (so users can call it with * "name(arguments...)", the concrete implementation must be a case class whose constructor * arguments are all Expressions types. @@ -335,15 +347,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express /** - * An expression that has two inputs that are expected to the be same type. If the two inputs have - * different types, the analyzer will find the tightest common type and do the proper type casting. + * A [[BinaryExpression]] that is an operator, with two properties: + * + * 1. The string representation is "x symbol y", rather than "funcName(x, y)". + * 2. Two inputs are expected to the be same type. If the two inputs have different types, + * the analyzer will find the tightest common type and do the proper type casting. */ -abstract class BinaryOperator extends BinaryExpression { +abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { self: Product => + /** + * Expected input type from both left/right child expressions, similar to the + * [[ImplicitCastInputTypes]] trait. + */ + def inputType: AbstractDataType + def symbol: String override def toString: String = s"($left $symbol $right)" + + override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType) + + override def checkInputDataTypes(): TypeCheckResult = { + // First call the checker for ExpectsInputTypes, and then check whether left and right have + // the same type. 
+ super.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (left.dataType != right.dataType) { + TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") + } else { + TypeCheckResult.TypeCheckSuccess + } + case TypeCheckResult.TypeCheckFailure(msg) => TypeCheckResult.TypeCheckFailure(msg) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 6fb3343bb63f2..22687acd68a97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -29,7 +29,7 @@ case class ScalaUDF( function: AnyRef, dataType: DataType, children: Seq[Expression], - inputTypes: Seq[DataType] = Nil) extends Expression with ExpectsInputTypes { + inputTypes: Seq[DataType] = Nil) extends Expression with ImplicitCastInputTypes { override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 8476af4a5d8d6..1a55a0876f303 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,23 +18,19 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -abstract class UnaryArithmetic extends UnaryExpression { - self: Product => + +case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) override def dataType: DataType = child.dataType -} -case class UnaryMinus(child: Expression) extends UnaryArithmetic { override def toString: String = s"-$child" - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "operator -") - private lazy val numeric = TypeUtils.getNumeric(dataType) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { @@ -45,9 +41,13 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { protected override def nullSafeEval(input: Any): Any = numeric.negate(input) } -case class UnaryPositive(child: Expression) extends UnaryArithmetic { +case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def prettyName: String = "positive" + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def dataType: DataType = child.dataType + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = defineCodeGen(ctx, ev, c => c) @@ -57,9 +57,11 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic { /** * A function that get the absolute value of the numeric value. 
*/ -case class Abs(child: Expression) extends UnaryArithmetic { - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function abs") +case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def dataType: DataType = child.dataType private lazy val numeric = TypeUtils.getNumeric(dataType) @@ -71,18 +73,6 @@ abstract class BinaryArithmetic extends BinaryOperator { override def dataType: DataType = left.dataType - override def checkInputDataTypes(): TypeCheckResult = { - if (left.dataType != right.dataType) { - TypeCheckResult.TypeCheckFailure( - s"differing types in ${this.getClass.getSimpleName} " + - s"(${left.dataType} and ${right.dataType}).") - } else { - checkTypesInternal(dataType) - } - } - - protected def checkTypesInternal(t: DataType): TypeCheckResult - /** Name of the function for this expression on a [[Decimal]] type. */ def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") @@ -104,62 +94,61 @@ private[sql] object BinaryArithmetic { } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "+" override def decimalMethod: String = "$plus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2) } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "-" override def decimalMethod: String = "$minus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2) } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "*" override def decimalMethod: String = "$times" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "/" override def decimalMethod: String = "$div" - override def nullable: Boolean = true override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val div: (Any, Any) => Any = dataType match { case ft: 
FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot @@ -215,17 +204,16 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "%" override def decimalMethod: String = "remainder" - override def nullable: Boolean = true override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val integral = dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] @@ -281,10 +269,11 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { - override def nullable: Boolean = left.nullable && right.nullable + // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least. - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(t, "function maxOf") + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def nullable: Boolean = left.nullable && right.nullable private lazy val ordering = TypeUtils.getOrdering(dataType) @@ -335,10 +324,11 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { - override def nullable: Boolean = left.nullable && right.nullable + // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least. - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(t, "function minOf") + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def nullable: Boolean = left.nullable && right.nullable private lazy val ordering = TypeUtils.getOrdering(dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala index 2d47124d247e7..af1abbcd2239b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -29,10 +27,10 @@ import org.apache.spark.sql.types._ * Code generation inherited from BinaryArithmetic. */ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "&" - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Bitwise + + override def symbol: String = "&" private lazy val and: (Any, Any) => Any = dataType match { case ByteType => @@ -54,10 +52,10 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme * Code generation inherited from BinaryArithmetic. 
*/ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "|" - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Bitwise + + override def symbol: String = "|" private lazy val or: (Any, Any) => Any = dataType match { case ByteType => @@ -79,10 +77,10 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet * Code generation inherited from BinaryArithmetic. */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "^" - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Bitwise + + override def symbol: String = "^" private lazy val xor: (Any, Any) => Any = dataType match { case ByteType => @@ -101,11 +99,13 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme /** * A function that calculates bitwise not(~) of a number. */ -case class BitwiseNot(child: Expression) extends UnaryArithmetic { - override def toString: String = s"~$child" +case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~") + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.Bitwise) + + override def dataType: DataType = child.dataType + + override def toString: String = s"~$child" private lazy val not: (Any) => Any = dataType match { case ByteType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index c31890e27fb54..4b7fe05dd4980 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -55,7 +55,7 @@ abstract class LeafMathExpression(c: Double, name: String) * @param name The short name of the function */ abstract class UnaryMathExpression(f: Double => Double, name: String) - extends UnaryExpression with Serializable with ExpectsInputTypes { self: Product => + extends UnaryExpression with Serializable with ImplicitCastInputTypes { self: Product => override def inputTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType @@ -89,7 +89,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + extends BinaryExpression with Serializable with ImplicitCastInputTypes { self: Product => override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) @@ -174,7 +174,7 @@ object Factorial { ) } -case class Factorial(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -251,7 +251,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia } case class Bin(child: Expression) - extends UnaryExpression with Serializable with ExpectsInputTypes { + extends 
UnaryExpression with Serializable with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType @@ -285,7 +285,7 @@ object Hex { * Otherwise if the number is a STRING, it converts each character into its hex representation * and returns the resulting STRING. Negative numbers would be treated as two's complement. */ -case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { // TODO: Create code-gen version. override def inputTypes: Seq[AbstractDataType] = @@ -329,7 +329,7 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes * Performs the inverse operation of HEX. * Resulting characters are returned as a byte array. */ -case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { // TODO: Create code-gen version. override def inputTypes: Seq[AbstractDataType] = Seq(StringType) @@ -416,7 +416,7 @@ case class Pow(left: Expression, right: Expression) * @param right number of bits to left shift. */ case class ShiftLeft(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -442,7 +442,7 @@ case class ShiftLeft(left: Expression, right: Expression) * @param right number of bits to left shift. */ case class ShiftRight(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -468,7 +468,7 @@ case class ShiftRight(left: Expression, right: Expression) * @param right the number of bits to right shift. */ case class ShiftRightUnsigned(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 3b59cd431b871..a269ec4a1e6dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -31,7 +31,7 @@ import org.apache.spark.unsafe.types.UTF8String * A function that calculates an MD5 128-bit checksum and returns it as a hex string * For input of type [[BinaryType]] */ -case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -55,7 +55,7 @@ case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes * the hash length is not one of the permitted values, the return value is NULL. 
*/ case class Sha2(left: Expression, right: Expression) - extends BinaryExpression with Serializable with ExpectsInputTypes { + extends BinaryExpression with Serializable with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -118,7 +118,7 @@ case class Sha2(left: Expression, right: Expression) * A function that calculates a sha1 hash value and returns it as a hex string * For input of type [[BinaryType]] or [[StringType]] */ -case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -138,7 +138,7 @@ case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputType * A function that computes a cyclic redundancy check value and returns it as a bigint * For input of type [[BinaryType]] */ -case class Crc32(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f74fd04619714..aa6c30e2f79f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -33,12 +33,17 @@ object InterpretedPredicate { } } + +/** + * An [[Expression]] that returns a boolean value. + */ trait Predicate extends Expression { self: Product => override def dataType: DataType = BooleanType } + trait PredicateHelper { protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = { condition match { @@ -70,7 +75,10 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } -case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes { + +case class Not(child: Expression) + extends UnaryExpression with Predicate with ImplicitCastInputTypes { + override def toString: String = s"NOT $child" override def inputTypes: Seq[DataType] = Seq(BooleanType) @@ -82,6 +90,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex } } + /** * Evaluates to `true` if `list` contains `value`. */ @@ -97,6 +106,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } + /** * Optimized version of In clause, when all filter values of In clause are * static. 
@@ -112,12 +122,12 @@ case class InSet(child: Expression, hset: Set[Any]) } } -case class And(left: Expression, right: Expression) - extends BinaryExpression with Predicate with ExpectsInputTypes { - override def toString: String = s"($left && $right)" +case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { + + override def inputType: AbstractDataType = BooleanType - override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def symbol: String = "&&" override def eval(input: InternalRow): Any = { val input1 = left.eval(input) @@ -161,12 +171,12 @@ case class And(left: Expression, right: Expression) } } -case class Or(left: Expression, right: Expression) - extends BinaryExpression with Predicate with ExpectsInputTypes { - override def toString: String = s"($left || $right)" +case class Or(left: Expression, right: Expression) extends BinaryOperator with Predicate { - override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def inputType: AbstractDataType = BooleanType + + override def symbol: String = "||" override def eval(input: InternalRow): Any = { val input1 = left.eval(input) @@ -210,21 +220,10 @@ case class Or(left: Expression, right: Expression) } } + abstract class BinaryComparison extends BinaryOperator with Predicate { self: Product => - override def checkInputDataTypes(): TypeCheckResult = { - if (left.dataType != right.dataType) { - TypeCheckResult.TypeCheckFailure( - s"differing types in ${this.getClass.getSimpleName} " + - s"(${left.dataType} and ${right.dataType}).") - } else { - checkTypesInternal(dataType) - } - } - - protected def checkTypesInternal(t: DataType): TypeCheckResult - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { if (ctx.isPrimitiveType(left.dataType)) { // faster version @@ -235,10 +234,12 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { } } + private[sql] object BinaryComparison { def unapply(e: BinaryComparison): Option[(Expression, Expression)] = Some((e.left, e.right)) } + /** An extractor that matches both standard 3VL equality and null-safe equality. 
*/ private[sql] object Equality { def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match { @@ -248,10 +249,12 @@ private[sql] object Equality { } } + case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "=" - override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess + override def inputType: AbstractDataType = AnyDataType + + override def symbol: String = "=" protected override def nullSafeEval(input1: Any, input2: Any): Any = { if (left.dataType != BinaryType) input1 == input2 @@ -263,13 +266,15 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison } } + case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { + + override def inputType: AbstractDataType = AnyDataType + override def symbol: String = "<=>" override def nullable: Boolean = false - override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess - override def eval(input: InternalRow): Any = { val input1 = left.eval(input) val input2 = right.eval(input) @@ -298,44 +303,48 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } + case class LessThan(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "<" - override protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def symbol: String = "<" private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } + case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "<=" - override protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def symbol: String = "<=" private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } + case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = ">" - override protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def symbol: String = ">" private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) } + case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = ">=" - override protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def symbol: String = ">=" private lazy val ordering = TypeUtils.getOrdering(left.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index f64899c1ed84c..03b55ce5fe7cc 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -trait StringRegexExpression extends ExpectsInputTypes { +trait StringRegexExpression extends ImplicitCastInputTypes { self: BinaryExpression => def escape(v: String): String @@ -105,7 +105,7 @@ case class RLike(left: Expression, right: Expression) override def toString: String = s"$left RLIKE $right" } -trait String2StringExpression extends ExpectsInputTypes { +trait String2StringExpression extends ImplicitCastInputTypes { self: UnaryExpression => def convert(v: UTF8String): UTF8String @@ -142,7 +142,7 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringComparison extends ExpectsInputTypes { +trait StringComparison extends ImplicitCastInputTypes { self: BinaryExpression => def compare(l: UTF8String, r: UTF8String): Boolean @@ -241,7 +241,7 @@ case class StringTrimRight(child: Expression) * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. */ case class StringInstr(str: Expression, substr: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = str override def right: Expression = substr @@ -265,7 +265,7 @@ case class StringInstr(str: Expression, substr: Expression) * in given string after position pos. */ case class StringLocate(substr: Expression, str: Expression, start: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with ImplicitCastInputTypes { def this(substr: Expression, str: Expression) = { this(substr, str, Literal(0)) @@ -306,7 +306,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) * Returns str, left-padded with pad to a length of len. */ case class StringLPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil override def foldable: Boolean = children.forall(_.foldable) @@ -344,7 +344,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) * Returns str, right-padded with pad to a length of len. */ case class StringRPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil override def foldable: Boolean = children.forall(_.foldable) @@ -413,7 +413,7 @@ case class StringFormat(children: Expression*) extends Expression { * Returns the string which repeat the given string value n times. */ case class StringRepeat(str: Expression, times: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = str override def right: Expression = times @@ -447,7 +447,7 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2 /** * Returns a n spaces string. 
*/ -case class StringSpace(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class StringSpace(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -467,7 +467,7 @@ case class StringSpace(child: Expression) extends UnaryExpression with ExpectsIn * Splits str around pat (pattern is a regular expression). */ case class StringSplit(str: Expression, pattern: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = str override def right: Expression = pattern @@ -488,7 +488,7 @@ case class StringSplit(str: Expression, pattern: Expression) * Defined for String and Binary types. */ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with ImplicitCastInputTypes { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) @@ -555,7 +555,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) /** * A function that return the length of the given string expression. */ -case class StringLength(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class StringLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -573,7 +573,7 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI * A function that return the Levenshtein distance between the two given strings. */ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression - with ExpectsInputTypes { + with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) @@ -591,7 +591,7 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres /** * Returns the numeric value of the first character of str. */ -case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -608,7 +608,7 @@ case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTyp /** * Converts the argument from binary to a base 64 string. */ -case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType) @@ -622,7 +622,7 @@ case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTy /** * Converts the argument from a base 64 string to BINARY. */ -case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = BinaryType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -636,7 +636,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInput * If either argument is null, the result will also be null. 
*/ case class Decode(bin: Expression, charset: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = bin override def right: Expression = charset @@ -655,7 +655,7 @@ case class Decode(bin: Expression, charset: Expression) * If either argument is null, the result will also be null. */ case class Encode(value: Expression, charset: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = value override def right: Expression = charset diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 3148309a2166f..0103ddcf9cfb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -32,14 +32,6 @@ object TypeUtils { } } - def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = { - if (t.isInstanceOf[IntegralType] || t == NullType) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure(s"$caller accepts integral types, not $t") - } - } - def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = { if (t.isInstanceOf[AtomicType] || t == NullType) { TypeCheckResult.TypeCheckSuccess diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 32f87440b4e37..f5715f7a829ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -96,6 +96,24 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType]) private[sql] object TypeCollection { + /** + * Types that can be ordered/compared. In the long run we should probably make this a trait + * that can be mixed into each data type, and perhaps create an [[AbstractDataType]]. + */ + val Ordered = TypeCollection( + BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType, + TimestampType, DateType, + StringType, BinaryType) + + /** + * Types that can be used in bitwise operations. + */ + val Bitwise = TypeCollection( + BooleanType, + ByteType, ShortType, IntegerType, LongType) + def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types) def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match { @@ -105,6 +123,23 @@ private[sql] object TypeCollection { } +/** + * An [[AbstractDataType]] that matches any concrete data types. + */ +protected[sql] object AnyDataType extends AbstractDataType { + + // Note that since AnyDataType matches any concrete types, defaultConcreteType should never + // be invoked. + override private[sql] def defaultConcreteType: DataType = throw new UnsupportedOperationException + + override private[sql] def simpleString: String = "any" + + override private[sql] def isSameType(other: DataType): Boolean = false + + override private[sql] def acceptsType(other: DataType): Boolean = true +} + + /** * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 9d0c69a2451d1..f0f17103991ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ case class TestFunction( children: Seq[Expression], - inputTypes: Seq[AbstractDataType]) extends Expression with ExpectsInputTypes { + inputTypes: Seq[AbstractDataType]) extends Expression with ImplicitCastInputTypes { override def nullable: Boolean = true override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def dataType: DataType = StringType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8e0551b23eea6..5958acbe009ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -49,7 +49,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { def assertErrorForDifferingTypes(expr: Expression): Unit = { assertError(expr, - s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).") + s"differing types in '${expr.prettyString}' (int and boolean)") } test("check types for unary arithmetic") { @@ -58,7 +58,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(BitwiseNot('stringField), "operator ~ accepts integral type") } - test("check types for binary arithmetic") { + ignore("check types for binary arithmetic") { // We will cast String to Double for binary arithmetic assertSuccess(Add('intField, 'stringField)) assertSuccess(Subtract('intField, 'stringField)) @@ -92,7 +92,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type") } - test("check types for predicates") { + ignore("check types for predicates") { // We will cast String to Double for binary comparison assertSuccess(EqualTo('intField, 'stringField)) assertSuccess(EqualNullSafe('intField, 'stringField)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index acb9a433de903..8e9b20a3ebe42 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -194,6 +194,32 @@ class HiveTypeCoercionSuite extends PlanTest { Project(Seq(Alias(transformed, "a")()), testRelation)) } + test("cast NullType for expresions that implement ExpectsInputTypes") { + import HiveTypeCoercionSuite._ + + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + AnyTypeUnaryExpression(Literal.create(null, NullType)), + AnyTypeUnaryExpression(Literal.create(null, NullType))) + + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + NumericTypeUnaryExpression(Literal.create(null, NullType)), + NumericTypeUnaryExpression(Cast(Literal.create(null, NullType), 
DoubleType))) + } + + test("cast NullType for binary operators") { + import HiveTypeCoercionSuite._ + + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), + AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) + + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), + NumericTypeBinaryOperator( + Cast(Literal.create(null, NullType), DoubleType), + Cast(Literal.create(null, NullType), DoubleType))) + } + test("coalesce casts") { ruleTest(HiveTypeCoercion.FunctionArgumentConversion, Coalesce(Literal(1.0) @@ -302,3 +328,33 @@ class HiveTypeCoercionSuite extends PlanTest { ) } } + + +object HiveTypeCoercionSuite { + + case class AnyTypeUnaryExpression(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def dataType: DataType = NullType + } + + case class NumericTypeUnaryExpression(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + override def dataType: DataType = NullType + } + + case class AnyTypeBinaryOperator(left: Expression, right: Expression) + extends BinaryOperator with ExpectsInputTypes { + override def dataType: DataType = NullType + override def inputType: AbstractDataType = AnyDataType + override def symbol: String = "anytype" + } + + case class NumericTypeBinaryOperator(left: Expression, right: Expression) + extends BinaryOperator with ExpectsInputTypes { + override def dataType: DataType = NullType + override def inputType: AbstractDataType = NumericType + override def symbol: String = "numerictype" + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 24bef21b999ea..b30b9f12258b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -375,6 +375,5 @@ class MathExpressionsSuite extends QueryTest { val df = Seq((1, -1, "abc")).toDF("a", "b", "c") checkAnswer(df.selectExpr("positive(a)"), Row(1)) checkAnswer(df.selectExpr("positive(b)"), Row(-1)) - checkAnswer(df.selectExpr("positive(c)"), Row("abc")) } } From c6b1a9e74e34267dc198e57a184c41498ca9d6a3 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 14 Jul 2015 22:57:39 -0700 Subject: [PATCH 15/46] Revert SPARK-6910 and SPARK-9027 Revert #7216 and #7386. 
These patches seem to be causing quite a few test failures:
```
Caused by: java.lang.reflect.InvocationTargetException
  at sun.reflect.GeneratedMethodAccessor322.invoke(Unknown Source)
  at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
  at java.lang.reflect.Method.invoke(Method.java:606)
  at org.apache.spark.sql.hive.client.Shim_v0_13.getPartitionsByFilter(HiveShim.scala:351)
  at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$getPartitionsByFilter$1.apply(ClientWrapper.scala:320)
  at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$getPartitionsByFilter$1.apply(ClientWrapper.scala:318)
  at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$withHiveState$1.apply(ClientWrapper.scala:180)
  at org.apache.spark.sql.hive.client.ClientWrapper.retryLocked(ClientWrapper.scala:135)
  at org.apache.spark.sql.hive.client.ClientWrapper.withHiveState(ClientWrapper.scala:172)
  at org.apache.spark.sql.hive.client.ClientWrapper.getPartitionsByFilter(ClientWrapper.scala:318)
  at org.apache.spark.sql.hive.client.HiveTable.getPartitions(ClientInterface.scala:78)
  at org.apache.spark.sql.hive.MetastoreRelation.getHiveQlPartitions(HiveMetastoreCatalog.scala:670)
  at org.apache.spark.sql.hive.execution.HiveTableScan.doExecute(HiveTableScan.scala:137)
  at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:90)
  at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:90)
  at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:147)
  at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:89)
  at org.apache.spark.sql.execution.Exchange$$anonfun$doExecute$1.apply(Exchange.scala:164)
  at org.apache.spark.sql.execution.Exchange$$anonfun$doExecute$1.apply(Exchange.scala:151)
  at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:48)
  ...
85 more Caused by: MetaException(message:Filtering is supported only on partition keys of type string) at org.apache.hadoop.hive.metastore.parser.ExpressionTree$FilterBuilder.setError(ExpressionTree.java:185) at org.apache.hadoop.hive.metastore.parser.ExpressionTree$LeafNode.getJdoFilterPushdownParam(ExpressionTree.java:452) at org.apache.hadoop.hive.metastore.parser.ExpressionTree$LeafNode.generateJDOFilterOverPartitions(ExpressionTree.java:357) at org.apache.hadoop.hive.metastore.parser.ExpressionTree$LeafNode.generateJDOFilter(ExpressionTree.java:279) at org.apache.hadoop.hive.metastore.parser.ExpressionTree$TreeNode.generateJDOFilter(ExpressionTree.java:243) at org.apache.hadoop.hive.metastore.parser.ExpressionTree.generateJDOFilterFragment(ExpressionTree.java:590) at org.apache.hadoop.hive.metastore.ObjectStore.makeQueryFilterString(ObjectStore.java:2417) at org.apache.hadoop.hive.metastore.ObjectStore.getPartitionsViaOrmFilter(ObjectStore.java:2029) at org.apache.hadoop.hive.metastore.ObjectStore.access$500(ObjectStore.java:146) at org.apache.hadoop.hive.metastore.ObjectStore$4.getJdoResult(ObjectStore.java:2332) ``` https://amplab.cs.berkeley.edu/jenkins/view/Spark-QA-Test/job/Spark-Master-Maven-with-YARN/2945/HADOOP_PROFILE=hadoop-2.4,label=centos/testReport/junit/org.apache.spark.sql.hive.execution/SortMergeCompatibilitySuite/auto_sortmerge_join_16/ Author: Michael Armbrust Closes #7409 from marmbrus/revertMetastorePushdown and squashes the following commits: 92fabd3 [Michael Armbrust] Revert SPARK-6910 and SPARK-9027 5d3bdf2 [Michael Armbrust] Revert "[SPARK-9027] [SQL] Generalize metastore predicate pushdown" --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 58 +++++++------- .../org/apache/spark/sql/hive/HiveShim.scala | 1 - .../spark/sql/hive/HiveStrategies.scala | 4 +- .../sql/hive/client/ClientInterface.scala | 11 +-- .../spark/sql/hive/client/ClientWrapper.scala | 21 +++-- .../spark/sql/hive/client/HiveShim.scala | 72 +---------------- .../sql/hive/execution/HiveTableScan.scala | 7 +- .../spark/sql/hive/client/FiltersSuite.scala | 78 ------------------- .../spark/sql/hive/client/VersionsSuite.scala | 8 -- .../sql/hive/execution/PruningSuite.scala | 2 +- 10 files changed, 44 insertions(+), 218 deletions(-) delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 5bdf68c83fca7..4b7a782c805a0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -301,9 +301,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val result = if (metastoreRelation.hiveQlTable.isPartitioned) { val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) val partitionColumnDataTypes = partitionSchema.map(_.dataType) - // We're converting the entire table into ParquetRelation, so predicates to Hive metastore - // are empty. 
- val partitions = metastoreRelation.getHiveQlPartitions().map { p => + val partitions = metastoreRelation.hiveQlPartitions.map { p => val location = p.getLocation val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map { case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) @@ -646,6 +644,32 @@ private[hive] case class MetastoreRelation new Table(tTable) } + @transient val hiveQlPartitions: Seq[Partition] = table.getAllPartitions.map { p => + val tPartition = new org.apache.hadoop.hive.metastore.api.Partition + tPartition.setDbName(databaseName) + tPartition.setTableName(tableName) + tPartition.setValues(p.values) + + val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() + tPartition.setSd(sd) + sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + + sd.setLocation(p.storage.location) + sd.setInputFormat(p.storage.inputFormat) + sd.setOutputFormat(p.storage.outputFormat) + + val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo + sd.setSerdeInfo(serdeInfo) + serdeInfo.setSerializationLib(p.storage.serde) + + val serdeParameters = new java.util.HashMap[String, String]() + serdeInfo.setParameters(serdeParameters) + table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + + new Partition(hiveQlTable, tPartition) + } + @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = { val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE) @@ -666,34 +690,6 @@ private[hive] case class MetastoreRelation } ) - def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = { - table.getPartitions(predicates).map { p => - val tPartition = new org.apache.hadoop.hive.metastore.api.Partition - tPartition.setDbName(databaseName) - tPartition.setTableName(tableName) - tPartition.setValues(p.values) - - val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() - tPartition.setSd(sd) - sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) - - sd.setLocation(p.storage.location) - sd.setInputFormat(p.storage.inputFormat) - sd.setOutputFormat(p.storage.outputFormat) - - val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo - sd.setSerdeInfo(serdeInfo) - serdeInfo.setSerializationLib(p.storage.serde) - - val serdeParameters = new java.util.HashMap[String, String]() - serdeInfo.setParameters(serdeParameters) - table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } - p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } - - new Partition(hiveQlTable, tPartition) - } - } - /** Only compare database and tablename, not alias. 
*/ override def sameResult(plan: LogicalPlan): Boolean = { plan match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index a357bb39ca7fd..d08c594151654 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -27,7 +27,6 @@ import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.io.{Input, Output} - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 9638a8201e190..ed359620a5f7f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -125,7 +125,7 @@ private[hive] trait HiveStrategies { InterpretedPredicate.create(castedPredicate) } - val partitions = relation.getHiveQlPartitions(pruningPredicates).filter { part => + val partitions = relation.hiveQlPartitions.filter { part => val partitionValues = part.getValues var i = 0 while (i < partitionValues.size()) { @@ -213,7 +213,7 @@ private[hive] trait HiveStrategies { projectList, otherPredicates, identity[Seq[Expression]], - HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil + HiveTableScan(_, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)) :: Nil case _ => Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index 1656587d14835..0a1d761a52f88 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -21,7 +21,6 @@ import java.io.PrintStream import java.util.{Map => JMap} import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException} -import org.apache.spark.sql.catalyst.expressions.Expression private[hive] case class HiveDatabase( name: String, @@ -72,12 +71,7 @@ private[hive] case class HiveTable( def isPartitioned: Boolean = partitionColumns.nonEmpty - def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = { - predicates match { - case Nil => client.getAllPartitions(this) - case _ => client.getPartitionsByFilter(this, predicates) - } - } + def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this) // Hive does not support backticks when passing names to the client. def qualifiedName: String = s"$database.$name" @@ -138,9 +132,6 @@ private[hive] trait ClientInterface { /** Returns all partitions for the given table. */ def getAllPartitions(hTable: HiveTable): Seq[HivePartition] - /** Returns partitions filtered by predicates for the given table. */ - def getPartitionsByFilter(hTable: HiveTable, predicates: Seq[Expression]): Seq[HivePartition] - /** Loads a static partition into an existing table. 
*/ def loadPartition( loadPath: String, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 8adda54754230..53f457ad4f3cc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -17,21 +17,25 @@ package org.apache.spark.sql.hive.client -import java.io.{File, PrintStream} -import java.util.{Map => JMap} +import java.io.{BufferedReader, InputStreamReader, File, PrintStream} +import java.net.URI +import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet} import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConversions._ import scala.language.reflectiveCalls import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.metastore.api.Database import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema} import org.apache.hadoop.hive.metastore.{TableType => HTableType} +import org.apache.hadoop.hive.metastore.api +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hadoop.hive.ql.metadata import org.apache.hadoop.hive.ql.metadata.Hive -import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.ql.{Driver, metadata} +import org.apache.hadoop.hive.ql.processors._ +import org.apache.hadoop.hive.ql.Driver import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.Expression @@ -312,13 +316,6 @@ private[hive] class ClientWrapper( shim.getAllPartitions(client, qlTable).map(toHivePartition) } - override def getPartitionsByFilter( - hTable: HiveTable, - predicates: Seq[Expression]): Seq[HivePartition] = withHiveState { - val qlTable = toQlTable(hTable) - shim.getPartitionsByFilter(client, qlTable, predicates).map(toHivePartition) - } - override def listTables(dbName: String): Seq[String] = withHiveState { client.getAllTables(dbName) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index d12778c7583df..1fa9d278e2a57 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -31,11 +31,6 @@ import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.serde.serdeConstants - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{StringType, IntegralType} /** * A shim that defines the interface between ClientWrapper and the underlying Hive library used to @@ -66,8 +61,6 @@ private[client] sealed abstract class Shim { def getAllPartitions(hive: Hive, table: Table): Seq[Partition] - def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition] - def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor def getDriverResults(driver: Driver): Seq[String] @@ -116,7 +109,7 @@ private[client] sealed abstract class Shim { } -private[client] class Shim_v0_12 extends Shim with Logging { +private[client] class Shim_v0_12 extends 
Shim { private lazy val startMethod = findStaticMethod( @@ -203,17 +196,6 @@ private[client] class Shim_v0_12 extends Shim with Logging { override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq - override def getPartitionsByFilter( - hive: Hive, - table: Table, - predicates: Seq[Expression]): Seq[Partition] = { - // getPartitionsByFilter() doesn't support binary comparison ops in Hive 0.12. - // See HIVE-4888. - logDebug("Hive 0.12 doesn't support predicate pushdown to metastore. " + - "Please use Hive 0.13 or higher.") - getAllPartitions(hive, table) - } - override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor] @@ -285,12 +267,6 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { classOf[Hive], "getAllPartitionsOf", classOf[Table]) - private lazy val getPartitionsByFilterMethod = - findMethod( - classOf[Hive], - "getPartitionsByFilter", - classOf[Table], - classOf[String]) private lazy val getCommandProcessorMethod = findStaticMethod( classOf[CommandProcessorFactory], @@ -312,52 +288,6 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq - /** - * Converts catalyst expression to the format that Hive's getPartitionsByFilter() expects, i.e. - * a string that represents partition predicates like "str_key=\"value\" and int_key=1 ...". - * - * Unsupported predicates are skipped. - */ - def convertFilters(table: Table, filters: Seq[Expression]): String = { - // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. - val varcharKeys = table.getPartitionKeys - .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME)) - .map(col => col.getName).toSet - - filters.collect { - case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => - s"${a.name} ${op.symbol} $v" - case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) => - s"$v ${op.symbol} ${a.name}" - - case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) - if !varcharKeys.contains(a.name) => - s"""${a.name} ${op.symbol} "$v"""" - case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) - if !varcharKeys.contains(a.name) => - s""""$v" ${op.symbol} ${a.name}""" - }.mkString(" and ") - } - - override def getPartitionsByFilter( - hive: Hive, - table: Table, - predicates: Seq[Expression]): Seq[Partition] = { - - // Hive getPartitionsByFilter() takes a string that represents partition - // predicates like "str_key=\"value\" and int_key=1 ..." 
- val filter = convertFilters(table, predicates) - val partitions = - if (filter.isEmpty) { - getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] - } else { - logDebug(s"Hive metastore filter is '$filter'.") - getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]] - } - - partitions.toSeq - } - override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index ba7eb15a1c0c6..d33da8242cc1d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -44,7 +44,7 @@ private[hive] case class HiveTableScan( requestedAttributes: Seq[Attribute], relation: MetastoreRelation, - partitionPruningPred: Seq[Expression])( + partitionPruningPred: Option[Expression])( @transient val context: HiveContext) extends LeafNode { @@ -56,7 +56,7 @@ case class HiveTableScan( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. - private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => + private[this] val boundPruningPred = partitionPruningPred.map { pred => require( pred.dataType == BooleanType, s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") @@ -133,8 +133,7 @@ case class HiveTableScan( protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) { hadoopReader.makeRDDForTable(relation.hiveQlTable) } else { - hadoopReader.makeRDDForPartitionedTable( - prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) + hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) } override def output: Seq[Attribute] = attributes diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala deleted file mode 100644 index 0efcf80bd4ea7..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.hive.client - -import scala.collection.JavaConversions._ - -import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.hive.serde.serdeConstants - -import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ - -/** - * A set of tests for the filter conversion logic used when pushing partition pruning into the - * metastore - */ -class FiltersSuite extends SparkFunSuite with Logging { - private val shim = new Shim_v0_13 - - private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test") - private val varCharCol = new FieldSchema() - varCharCol.setName("varchar") - varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME) - testTable.setPartCols(varCharCol :: Nil) - - filterTest("string filter", - (a("stringcol", StringType) > Literal("test")) :: Nil, - "stringcol > \"test\"") - - filterTest("string filter backwards", - (Literal("test") > a("stringcol", StringType)) :: Nil, - "\"test\" > stringcol") - - filterTest("int filter", - (a("intcol", IntegerType) === Literal(1)) :: Nil, - "intcol = 1") - - filterTest("int filter backwards", - (Literal(1) === a("intcol", IntegerType)) :: Nil, - "1 = intcol") - - filterTest("int and string filter", - (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil, - "1 = intcol and \"a\" = strcol") - - filterTest("skip varchar", - (Literal("") === a("varchar", StringType)) :: Nil, - "") - - private def filterTest(name: String, filters: Seq[Expression], result: String) = { - test(name){ - val converted = shim.convertFilters(testTable, filters) - if (converted != result) { - fail( - s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'") - } - } - } - - private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)() -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 3eb127e23d486..d52e162acbd04 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -20,9 +20,7 @@ package org.apache.spark.sql.hive.client import java.io.File import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.Utils /** @@ -153,12 +151,6 @@ class VersionsSuite extends SparkFunSuite with Logging { client.getAllPartitions(client.getTable("default", "src_part")) } - test(s"$version: getPartitionsByFilter") { - client.getPartitionsByFilter(client.getTable("default", "src_part"), Seq(EqualTo( - AttributeReference("key", IntegerType, false)(NamedExpression.newExprId), - Literal(1)))) - } - test(s"$version: loadPartition") { client.loadPartition( emptyDir, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index e83a7dc77e329..de6a41ce5bfcb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -151,7 +151,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { case p @ HiveTableScan(columns, relation, _) => val columnNames = columns.map(_.name) val partValues = if (relation.table.isPartitioned) { - p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues) + p.prunePartitions(relation.hiveQlPartitions).map(_.getValues) } else { Seq.empty } From 4692769655e09d129a62a89a8ffb5d635675aa4d Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Tue, 14 Jul 2015 23:27:42 -0700 Subject: [PATCH 16/46] [SPARK-6259] [MLLIB] Python API for LDA I implemented the Python API for LDA. But I didn't implemented a method for `LDAModel.describeTopics()`, beause it's a little hard to implement it now. And adding document about that and an example code would fit for another issue. TODO: LDAModel.describeTopics() in Python must be also implemented. But it would be nice to fit for another issue. Implementing it is a little hard, since the return value of `describeTopics` in Scala consists of Tuple classes. Author: Yu ISHIKAWA Closes #6791 from yu-iskw/SPARK-6259 and squashes the following commits: 6855f59 [Yu ISHIKAWA] LDA inherits object 28bd165 [Yu ISHIKAWA] Change the place of testing code d7a332a [Yu ISHIKAWA] Remove the doc comment about the optimizer's default value 083e226 [Yu ISHIKAWA] Add the comment about the supported values and the default value of `optimizer` 9f8bed8 [Yu ISHIKAWA] Simplify casting faa9764 [Yu ISHIKAWA] Add some comments for the LDA paramters 98f645a [Yu ISHIKAWA] Remove the interface for `describeTopics`. Because it is not implemented. 57ac03d [Yu ISHIKAWA] Remove the unnecessary import in Python unit testing 73412c3 [Yu ISHIKAWA] Fix the typo 2278829 [Yu ISHIKAWA] Fix the indentation 39514ec [Yu ISHIKAWA] Modify how to cast the input data 8117e18 [Yu ISHIKAWA] Fix the validation problems by `lint-scala` 77fd1b7 [Yu ISHIKAWA] Not use LabeledPoint 68f0653 [Yu ISHIKAWA] Support some parameters for `ALS.train()` in Python 25ef2ac [Yu ISHIKAWA] Resolve conflicts with rebasing --- .../mllib/api/python/PythonMLLibAPI.scala | 33 ++++++++++ python/pyspark/mllib/clustering.py | 66 ++++++++++++++++++- 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index e628059c4af8e..c58a64001d9a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -502,6 +502,39 @@ private[python] class PythonMLLibAPI extends Serializable { new MatrixFactorizationModelWrapper(model) } + /** + * Java stub for Python mllib LDA.run() + */ + def trainLDAModel( + data: JavaRDD[java.util.List[Any]], + k: Int, + maxIterations: Int, + docConcentration: Double, + topicConcentration: Double, + seed: java.lang.Long, + checkpointInterval: Int, + optimizer: String): LDAModel = { + val algo = new LDA() + .setK(k) + .setMaxIterations(maxIterations) + .setDocConcentration(docConcentration) + .setTopicConcentration(topicConcentration) + .setCheckpointInterval(checkpointInterval) + .setOptimizer(optimizer) + + if (seed != null) algo.setSeed(seed) + + val documents = data.rdd.map(_.asScala.toArray).map { r => + r(0) match { + case i: java.lang.Integer => (i.toLong, r(1).asInstanceOf[Vector]) + case i: java.lang.Long => (i.toLong, 
r(1).asInstanceOf[Vector]) + case _ => throw new IllegalArgumentException("input values contains invalid type value.") + } + } + algo.run(documents) + } + + /** * Java stub for Python mllib FPGrowth.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index ed4d78a2c6788..8a92f6911c24b 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -31,13 +31,15 @@ from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector +from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.stat.distribution import MultivariateGaussian from pyspark.mllib.util import Saveable, Loader, inherit_doc, JavaLoader, JavaSaveable from pyspark.streaming import DStream __all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture', 'PowerIterationClusteringModel', 'PowerIterationClustering', - 'StreamingKMeans', 'StreamingKMeansModel'] + 'StreamingKMeans', 'StreamingKMeansModel', + 'LDA', 'LDAModel'] @inherit_doc @@ -563,6 +565,68 @@ def predictOnValues(self, dstream): return dstream.mapValues(lambda x: self._model.predict(x)) +class LDAModel(JavaModelWrapper): + + """ A clustering model derived from the LDA method. + + Latent Dirichlet Allocation (LDA), a topic model designed for text documents. + Terminology + - "word" = "term": an element of the vocabulary + - "token": instance of a term appearing in a document + - "topic": multinomial distribution over words representing some concept + References: + - Original LDA paper (journal version): + Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + + >>> from pyspark.mllib.linalg import Vectors + >>> from numpy.testing import assert_almost_equal + >>> data = [ + ... [1, Vectors.dense([0.0, 1.0])], + ... [2, SparseVector(2, {0: 1.0})], + ... ] + >>> rdd = sc.parallelize(data) + >>> model = LDA.train(rdd, k=2) + >>> model.vocabSize() + 2 + >>> topics = model.topicsMatrix() + >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]]) + >>> assert_almost_equal(topics, topics_expect, 1) + """ + + def topicsMatrix(self): + """Inferred topics, where each topic is represented by a distribution over terms.""" + return self.call("topicsMatrix").toArray() + + def vocabSize(self): + """Vocabulary size (number of terms or terms in the vocabulary)""" + return self.call("vocabSize") + + +class LDA(object): + + @classmethod + def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, + topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"): + """Train a LDA model. + + :param rdd: RDD of data points + :param k: Number of clusters you want + :param maxIterations: Number of iterations. Default to 20 + :param docConcentration: Concentration parameter (commonly named "alpha") + for the prior placed on documents' distributions over topics ("theta"). + :param topicConcentration: Concentration parameter (commonly named "beta" or "eta") + for the prior placed on topics' distributions over terms. + :param seed: Random Seed + :param checkpointInterval: Period (in iterations) between checkpoints. + :param optimizer: LDAOptimizer used to perform the actual calculation. + Currently "em", "online" are supported. Default to "em". 
+ """ + model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations, + docConcentration, topicConcentration, seed, + checkpointInterval, optimizer) + return LDAModel(model) + + def _test(): import doctest import pyspark.mllib.clustering From 3f6296fed4ee10f53e728eb1e02f13338839b94d Mon Sep 17 00:00:00 2001 From: FlytxtRnD Date: Tue, 14 Jul 2015 23:29:02 -0700 Subject: [PATCH 17/46] [SPARK-8018] [MLLIB] KMeans should accept initial cluster centers as param This allows Kmeans to be initialized using an existing set of cluster centers provided as a KMeansModel object. This mode of initialization performs a single run. Author: FlytxtRnD Closes #6737 from FlytxtRnD/Kmeans-8018 and squashes the following commits: 94b56df [FlytxtRnD] style correction ef95ee2 [FlytxtRnD] style correction c446c58 [FlytxtRnD] documentation and numRuns warning change 06d13ef [FlytxtRnD] numRuns corrected d12336e [FlytxtRnD] numRuns variable modifications 07f8554 [FlytxtRnD] remove setRuns from setIntialModel e721dfe [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into Kmeans-8018 242ead1 [FlytxtRnD] corrected == to === in assert 714acb5 [FlytxtRnD] added numRuns 60c8ce2 [FlytxtRnD] ignore runs parameter and initialModel test suite changed 582e6d9 [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into Kmeans-8018 3f5fc8e [FlytxtRnD] test case modified and one runs condition added cd5dc5c [FlytxtRnD] Merge remote-tracking branch 'upstream/master' into Kmeans-8018 16f1b53 [FlytxtRnD] Merge branch 'Kmeans-8018', remote-tracking branch 'upstream/master' into Kmeans-8018 e9c35d7 [FlytxtRnD] Remove getInitialModel and match cluster count criteria 6959861 [FlytxtRnD] Accept initial cluster centers in KMeans --- docs/mllib-clustering.md | 1 + .../spark/mllib/clustering/KMeans.scala | 41 ++++++++++++++++--- .../spark/mllib/clustering/KMeansSuite.scala | 22 ++++++++++ 3 files changed, 58 insertions(+), 6 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index d72dc20a5ad6e..0fc7036bffeaf 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -33,6 +33,7 @@ guaranteed to find a globally optimal solution, and when run multiple times on a given dataset, the algorithm returns the best clustering result). * *initializationSteps* determines the number of steps in the k-means\|\| algorithm. * *epsilon* determines the distance threshold within which we consider k-means to have converged. +* *initialModel* is an optional set of cluster centers used for initialization. If this parameter is supplied, only one run is performed. **Examples** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 0f8d6a399682d..68297130a7b03 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -156,6 +156,21 @@ class KMeans private ( this } + // Initial cluster centers can be provided as a KMeansModel object rather than using the + // random or k-means|| initializationMode + private var initialModel: Option[KMeansModel] = None + + /** + * Set the initial starting point, bypassing the random initialization or k-means|| + * The condition model.k == this.k must be met, failure results + * in an IllegalArgumentException. 
+ */ + def setInitialModel(model: KMeansModel): this.type = { + require(model.k == k, "mismatched cluster count") + initialModel = Some(model) + this + } + /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. @@ -193,20 +208,34 @@ class KMeans private ( val initStartTime = System.nanoTime() - val centers = if (initializationMode == KMeans.RANDOM) { - initRandom(data) + // Only one run is allowed when initialModel is given + val numRuns = if (initialModel.nonEmpty) { + if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.") + 1 } else { - initKMeansParallel(data) + runs } + val centers = initialModel match { + case Some(kMeansCenters) => { + Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s))) + } + case None => { + if (initializationMode == KMeans.RANDOM) { + initRandom(data) + } else { + initKMeansParallel(data) + } + } + } val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) + " seconds.") - val active = Array.fill(runs)(true) - val costs = Array.fill(runs)(0.0) + val active = Array.fill(numRuns)(true) + val costs = Array.fill(numRuns)(0.0) - var activeRuns = new ArrayBuffer[Int] ++ (0 until runs) + var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns) var iteration = 0 val iterationStartTime = System.nanoTime() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 0dbbd7127444f..3003c62d9876c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -278,6 +278,28 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("Initialize using given cluster centers") { + val points = Seq( + Vectors.dense(0.0, 0.0), + Vectors.dense(1.0, 0.0), + Vectors.dense(0.0, 1.0), + Vectors.dense(1.0, 1.0) + ) + val rdd = sc.parallelize(points, 3) + // creating an initial model + val initialModel = new KMeansModel(Array(points(0), points(2))) + + val returnModel = new KMeans() + .setK(2) + .setMaxIterations(0) + .setInitialModel(initialModel) + .run(rdd) + // comparing the returned model and the initial model + assert(returnModel.clusterCenters(0) === initialModel.clusterCenters(0)) + assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1)) + } + } object KMeansSuite extends SparkFunSuite { From f0e129740dc2442a21dfa7fbd97360df87291095 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 14 Jul 2015 23:30:41 -0700 Subject: [PATCH 18/46] [SPARK-8279][SQL]Add math function round JIRA: https://issues.apache.org/jira/browse/SPARK-8279 Author: Yijie Shen Closes #6938 from yijieshen/udf_round_3 and squashes the following commits: 07a124c [Yijie Shen] remove useless def children 392b65b [Yijie Shen] add negative scale test in DecimalSuite 61760ee [Yijie Shen] address reviews 302a78a [Yijie Shen] Add dataframe function test 31dfe7c [Yijie Shen] refactor round to make it readable 8c7a949 [Yijie Shen] rebase & inputTypes update 9555e35 [Yijie Shen] tiny style fix d10be4a [Yijie Shen] use TypeCollection to specify wanted input and implicit cast c3b9839 [Yijie Shen] rely on implict cast to handle string input b0bff79 [Yijie Shen] make round's inner method's name more meaningful 9bd6930 [Yijie Shen] revert 
accidental change e6f44c4 [Yijie Shen] refactor eval and genCode 1b87540 [Yijie Shen] modify checkInputDataTypes using foldable 5486b2d [Yijie Shen] DataFrame API modification 2077888 [Yijie Shen] codegen versioned eval 6cd9a64 [Yijie Shen] refactor Round's constructor 9be894e [Yijie Shen] add round functions in o.a.s.sql.functions 7c83e13 [Yijie Shen] more tests on round 56db4bb [Yijie Shen] Add decimal support to Round 7e163ae [Yijie Shen] style fix 653d047 [Yijie Shen] Add math function round --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 203 +++++++++++++++++- .../ExpressionTypeCheckingSuite.scala | 17 ++ .../expressions/MathFunctionsSuite.scala | 44 ++++ .../sql/types/decimal/DecimalSuite.scala | 23 +- .../org/apache/spark/sql/functions.scala | 32 +++ .../spark/sql/MathExpressionsSuite.scala | 15 ++ .../execution/HiveCompatibilitySuite.scala | 7 +- 8 files changed, 329 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6b1a94e4b2ad4..ec75f51d5e4ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -117,6 +117,7 @@ object FunctionRegistry { expression[Pow]("power"), expression[UnaryPositive]("positive"), expression[Rint]("rint"), + expression[Round]("round"), expression[ShiftLeft]("shiftleft"), expression[ShiftRight]("shiftright"), expression[ShiftRightUnsigned]("shiftrightunsigned"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 4b7fe05dd4980..a7ad452ef4943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -520,3 +522,202 @@ case class Logarithm(left: Expression, right: Expression) """ } } + +/** + * Round the `child`'s result to `scale` decimal place when `scale` >= 0 + * or round at integral part when `scale` < 0. + * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30. + * + * Child of IntegralType would eval to itself when `scale` >= 0. + * Child of FractionalType whose value is NaN or Infinite would always eval to itself. 
+ * + * Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]], + * which leads to scale update in DecimalType's [[PrecisionInfo]] + * + * @param child expr to be round, all [[NumericType]] is allowed as Input + * @param scale new scale to be round to, this should be a constant int at runtime + */ +case class Round(child: Expression, scale: Expression) + extends BinaryExpression with ExpectsInputTypes { + + import BigDecimal.RoundingMode.HALF_UP + + def this(child: Expression) = this(child, Literal(0)) + + override def left: Expression = child + override def right: Expression = scale + + // round of Decimal would eval to null if it fails to `changePrecision` + override def nullable: Boolean = true + + override def foldable: Boolean = child.foldable + + override lazy val dataType: DataType = child.dataType match { + // if the new scale is bigger which means we are scaling up, + // keep the original scale as `Decimal` does + case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale) + case t => t + } + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case TypeCheckSuccess => + if (scale.foldable) { + TypeCheckSuccess + } else { + TypeCheckFailure("Only foldable Expression is allowed for scale arguments") + } + case f => f + } + } + + // Avoid repeated evaluation since `scale` is a constant int, + // avoid unnecessary `child` evaluation in both codegen and non-codegen eval + // by checking if scaleV == null as well. + private lazy val scaleV: Any = scale.eval(EmptyRow) + private lazy val _scale: Int = scaleV.asInstanceOf[Int] + + override def eval(input: InternalRow): Any = { + if (scaleV == null) { // if scale is null, no need to eval its child at all + null + } else { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + nullSafeEval(evalE) + } + } + } + + // not overriding since _scale is a constant int at runtime + def nullSafeEval(input1: Any): Any = { + child.dataType match { + case _: DecimalType => + val decimal = input1.asInstanceOf[Decimal] + if (decimal.changePrecision(decimal.precision, _scale)) decimal else null + case ByteType => + BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte + case ShortType => + BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort + case IntegerType => + BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt + case LongType => + BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong + case FloatType => + val f = input1.asInstanceOf[Float] + if (f.isNaN || f.isInfinite) { + f + } else { + BigDecimal(f).setScale(_scale, HALF_UP).toFloat + } + case DoubleType => + val d = input1.asInstanceOf[Double] + if (d.isNaN || d.isInfinite) { + d + } else { + BigDecimal(d).setScale(_scale, HALF_UP).toDouble + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val ce = child.gen(ctx) + + val evaluationCode = child.dataType match { + case _: DecimalType => + s""" + if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) { + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.isNull} = true; + }""" + case ByteType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). 
+ setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case ShortType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case IntegerType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case LongType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case FloatType => // if child eval to NaN or Infinity, just return it. + if (_scale == 0) { + s""" + if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = Math.round(${ce.primitive}); + }""" + } else { + s""" + if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); + }""" + } + case DoubleType => // if child eval to NaN or Infinity, just return it. + if (_scale == 0) { + s""" + if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = Math.round(${ce.primitive}); + }""" + } else { + s""" + if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). 
+ setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); + }""" + } + } + + if (scaleV == null) { // if scale is null, no need to eval its child at all + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + s""" + ${ce.code} + boolean ${ev.isNull} = ${ce.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + $evaluationCode + } + """ + } + } + + override def prettyName: String = "round" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 5958acbe009ca..e885a18254ea0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -52,6 +52,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { s"differing types in '${expr.prettyString}' (int and boolean)") } + def assertErrorWithImplicitCast(expr: Expression, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + assertSuccess(expr) + } + assert(e.getMessage.contains(errorMessage)) + } + test("check types for unary arithmetic") { assertError(UnaryMinus('stringField), "operator - accepts numeric type") assertError(Abs('stringField), "function abs accepts numeric type") @@ -171,4 +178,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), "Odd position only allow foldable and not-null StringType expressions") } + + test("check types for ROUND") { + assertErrorWithImplicitCast(Round(Literal(null), 'booleanField), + "data type mismatch: argument 2 is expected to be of type int") + assertErrorWithImplicitCast(Round(Literal(null), 'complexField), + "data type mismatch: argument 2 is expected to be of type int") + assertSuccess(Round(Literal(null), Literal(null))) + assertError(Round('booleanField, 'intField), + "data type mismatch: argument 1 is expected to be of type numeric") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 7ca9e30b2bcd5..52a874a9d89ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.math.BigDecimal.RoundingMode + import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite @@ -336,4 +338,46 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { null, create_row(null)) } + + test("round") { + val domain = -6 to 6 + val doublePi: Double = math.Pi + val shortPi: Short = 31415 + val intPi: Int = 314159265 + val longPi: Long = 31415926535897932L + val bdPi: BigDecimal = BigDecimal(31415927L, 7) + + val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142, + 3.1416, 3.14159, 3.141593) + + val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++ + Seq.fill[Short](7)(31415) + + val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300, + 
314159270) ++ Seq.fill(7)(314159265) + + val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L, + 31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++ + Seq.fill(7)(31415926535897932L) + + val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), + BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), + BigDecimal(3.141593), BigDecimal(3.1415927)) + + domain.zipWithIndex.foreach { case (scale, i) => + checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) + checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) + checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) + checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) + } + + // round_scale > current_scale would result in precision increase + // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null + (0 to 7).foreach { i => + checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) + } + (8 to 10).foreach { scale => + checkEvaluation(Round(bdPi, scale), null, EmptyRow) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index 030bb6d21b18b..f0c849d1a1564 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -24,14 +24,14 @@ import org.scalatest.PrivateMethodTester import scala.language.postfixOps class DecimalSuite extends SparkFunSuite with PrivateMethodTester { - test("creating decimals") { - /** Check that a Decimal has the given string representation, precision and scale */ - def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { - assert(d.toString === string) - assert(d.precision === precision) - assert(d.scale === scale) - } + /** Check that a Decimal has the given string representation, precision and scale */ + private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { + assert(d.toString === string) + assert(d.precision === precision) + assert(d.scale === scale) + } + test("creating decimals") { checkDecimal(new Decimal(), "0", 1, 0) checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) @@ -53,6 +53,15 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0)) } + test("creating decimals with negative scale") { + checkDecimal(Decimal(BigDecimal("98765"), 5, -3), "9.9E+4", 5, -3) + checkDecimal(Decimal(BigDecimal("314.159"), 6, -2), "3E+2", 6, -2) + checkDecimal(Decimal(BigDecimal(1.579e12), 4, -9), "1.579E+12", 4, -9) + checkDecimal(Decimal(BigDecimal(1.579e12), 4, -10), "1.58E+12", 4, -10) + checkDecimal(Decimal(103050709L, 9, -10), "1.03050709E+18", 9, -10) + checkDecimal(Decimal(1e8.toLong, 10, -10), "1.00000000E+18", 10, -10) + } + test("double and long values") { /** Check that a Decimal converts to the given double and long values */ def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0d4e160ed8057..5119ee31d852d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1389,6 +1389,38 @@ object functions { */ def rint(columnName: String): Column = rint(Column(columnName)) + /** + * Returns the value of the column `e` rounded to 0 decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column): Column = round(e.expr, 0) + + /** + * Returns the value of the given column rounded to 0 decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(columnName: String): Column = round(Column(columnName), 0) + + /** + * Returns the value of `e` rounded to `scale` decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) + + /** + * Returns the value of the given column rounded to `scale` decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(columnName: String, scale: Int): Column = round(Column(columnName), scale) + /** * Shift the the given value numBits left. If the given value is a long value, this function * will return a long value else it will return an integer value. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index b30b9f12258b9..087126bb2e513 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -198,6 +198,21 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction(rint, math.rint) } + test("round") { + val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a") + checkAnswer( + df.select(round('a), round('a, -1), round('a, -2)), + Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) + ) + + val pi = 3.1415 + checkAnswer( + ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), + Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142)) + ) + } + test("exp") { testOneToOneMathFunction(exp, math.exp) } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index c884c399281a8..4ada64bc21966 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -221,9 +221,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_when", "udf_case", - // Needs constant object inspectors - "udf_round", - // the table src(key INT, value STRING) is not the same as HIVE unittest. In Hive // is src(key STRING, value STRING), and in the reflect.q, it failed in // Integer.valueOf, which expect the first argument passed as STRING type not INT. 
@@ -918,8 +915,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_regexp_replace", "udf_repeat", "udf_rlike", - "udf_round", - // "udf_round_3", TODO: FIX THIS failed due to cast exception + // "udf_round", turn this on after we figure out null vs nan vs infinity + "udf_round_3", "udf_rpad", "udf_rtrim", "udf_second", From 1bb8accbc95a0f0856a8bb715f1e94c3ff96a8c7 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 14 Jul 2015 23:50:57 -0700 Subject: [PATCH 19/46] [SPARK-8997] [MLLIB] Performance improvements in LocalPrefixSpan Improves the performance of LocalPrefixSpan by implementing optimizations proposed in [SPARK-8997](https://issues.apache.org/jira/browse/SPARK-8997) Author: Feynman Liang Author: Feynman Liang Author: Xiangrui Meng Closes #7360 from feynmanliang/SPARK-8997-improve-prefixspan and squashes the following commits: 59db2f5 [Feynman Liang] Merge pull request #1 from mengxr/SPARK-8997 91e4357 [Xiangrui Meng] update LocalPrefixSpan impl 9212256 [Feynman Liang] MengXR code review comments f055d82 [Feynman Liang] Fix failing scalatest 2e00cba [Feynman Liang] Depth first projections 70b93e3 [Feynman Liang] Performance improvements in LocalPrefixSpan, fix tests --- .../spark/mllib/fpm/LocalPrefixSpan.scala | 95 ++++++++----------- .../apache/spark/mllib/fpm/PrefixSpan.scala | 5 +- .../spark/mllib/fpm/PrefixSpanSuite.scala | 14 +-- 3 files changed, 44 insertions(+), 70 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 39c48b084e550..7ead6327486cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -17,58 +17,49 @@ package org.apache.spark.mllib.fpm +import scala.collection.mutable + import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental /** - * - * :: Experimental :: - * * Calculate all patterns of a projected database in local. */ -@Experimental private[fpm] object LocalPrefixSpan extends Logging with Serializable { /** * Calculate all patterns of a projected database. * @param minCount minimum count * @param maxPatternLength maximum pattern length - * @param prefix prefix - * @param projectedDatabase the projected dabase + * @param prefixes prefixes in reversed order + * @param database the projected database * @return a set of sequential pattern pairs, - * the key of pair is sequential pattern (a list of items), + * the key of pair is sequential pattern (a list of items in reversed order), * the value of pair is the pattern's count. 
*/ def run( minCount: Long, maxPatternLength: Int, - prefix: Array[Int], - projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = { - val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase) - val frequentPatternAndCounts = frequentPrefixAndCounts - .map(x => (prefix ++ Array(x._1), x._2)) - val prefixProjectedDatabases = getPatternAndProjectedDatabase( - prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase) - - val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength - if (continueProcess) { - val nextPatterns = prefixProjectedDatabases - .map(x => run(minCount, maxPatternLength, x._1, x._2)) - .reduce(_ ++ _) - frequentPatternAndCounts ++ nextPatterns - } else { - frequentPatternAndCounts + prefixes: List[Int], + database: Array[Array[Int]]): Iterator[(List[Int], Long)] = { + if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty + val frequentItemAndCounts = getFreqItemAndCounts(minCount, database) + val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains)) + frequentItemAndCounts.iterator.flatMap { case (item, count) => + val newPrefixes = item :: prefixes + val newProjected = project(filteredDatabase, item) + Iterator.single((newPrefixes, count)) ++ + run(minCount, maxPatternLength, newPrefixes, newProjected) } } /** - * calculate suffix sequence following a prefix in a sequence - * @param prefix prefix - * @param sequence sequence + * Calculate suffix sequence immediately after the first occurrence of an item. + * @param item item to get suffix after + * @param sequence sequence to extract suffix from * @return suffix sequence */ - def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = { - val index = sequence.indexOf(prefix) + def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = { + val index = sequence.indexOf(item) if (index == -1) { Array() } else { @@ -76,38 +67,28 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { } } + def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = { + database + .map(getSuffix(prefix, _)) + .filter(_.nonEmpty) + } + /** * Generates frequent items by filtering the input data using minimal count level. - * @param minCount the absolute minimum count - * @param sequences sequences data - * @return array of item and count pair + * @param minCount the minimum count for an item to be frequent + * @param database database of sequences + * @return freq item to count map */ private def getFreqItemAndCounts( minCount: Long, - sequences: Array[Array[Int]]): Array[(Int, Long)] = { - sequences.flatMap(_.distinct) - .groupBy(x => x) - .mapValues(_.length.toLong) - .filter(_._2 >= minCount) - .toArray - } - - /** - * Get the frequent prefixes' projected database. 
- * @param prePrefix the frequent prefixes' prefix - * @param frequentPrefixes frequent prefixes - * @param sequences sequences data - * @return prefixes and projected database - */ - private def getPatternAndProjectedDatabase( - prePrefix: Array[Int], - frequentPrefixes: Array[Int], - sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = { - val filteredProjectedDatabase = sequences - .map(x => x.filter(frequentPrefixes.contains(_))) - frequentPrefixes.map { x => - val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty) - (prePrefix ++ Array(x), sub) - }.filter(x => x._2.nonEmpty) + database: Array[Array[Int]]): mutable.Map[Int, Long] = { + // TODO: use PrimitiveKeyOpenHashMap + val counts = mutable.Map[Int, Long]().withDefaultValue(0L) + database.foreach { sequence => + sequence.distinct.foreach { item => + counts(item) += 1L + } + } + counts.filter(_._2 >= minCount) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 9d8c60ef0fc45..6f52db7b073ae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -150,8 +150,9 @@ class PrefixSpan private ( private def getPatternsInLocal( minCount: Long, data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = { - data.flatMap { x => - LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2) + data.flatMap { case (prefix, projDB) => + LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB) + .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) } } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 413436d3db85f..9f107c89f6d80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -18,9 +18,8 @@ package org.apache.spark.mllib.fpm import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.rdd.RDD -class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext { +class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { test("PrefixSpan using Integer type") { @@ -48,15 +47,8 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext { def compareResult( expectedValue: Array[(Array[Int], Long)], actualValue: Array[(Array[Int], Long)]): Boolean = { - val sortedExpectedValue = expectedValue.sortWith{ (x, y) => - x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2 - } - val sortedActualValue = actualValue.sortWith{ (x, y) => - x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2 - } - sortedExpectedValue.zip(sortedActualValue) - .map(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2) - .reduce(_&&_) + expectedValue.map(x => (x._1.toSeq, x._2)).toSet == + actualValue.map(x => (x._1.toSeq, x._2)).toSet } val prefixspan = new PrefixSpan() From 14935d846a4f6bcd4d2a448a8f112fa5dee769ba Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 15 Jul 2015 00:12:21 -0700 Subject: [PATCH 20/46] [HOTFIX][SQL] Unit test breaking. 
--- .../sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index e885a18254ea0..a4ce1825cab28 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -60,9 +60,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for unary arithmetic") { - assertError(UnaryMinus('stringField), "operator - accepts numeric type") - assertError(Abs('stringField), "function abs accepts numeric type") - assertError(BitwiseNot('stringField), "operator ~ accepts integral type") + assertError(UnaryMinus('stringField), "expected to be of type numeric") + assertError(Abs('stringField), "expected to be of type numeric") + assertError(BitwiseNot('stringField), "type (boolean or tinyint or smallint or int or bigint)") } ignore("check types for binary arithmetic") { From adb33d3665770daf2ccb8915d19e198be9dc3b47 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 15 Jul 2015 17:30:57 +0900 Subject: [PATCH 21/46] [SPARK-9012] [WEBUI] Escape Accumulators in the task table If running the following codes, the task table will be broken because accumulators aren't escaped. ``` val a = sc.accumulator(1, "") sc.parallelize(1 to 10).foreach(i => a += i) ``` Before this fix, screen shot 2015-07-13 at 8 02 44 pm After this fix, screen shot 2015-07-13 at 8 14 32 pm Author: zsxwing Closes #7369 from zsxwing/SPARK-9012 and squashes the following commits: a83c9b6 [zsxwing] Escape Accumulators in the task table --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index ff0a339a39c65..27b82aaddd2e4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -692,7 +692,9 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val gettingResultTime = getGettingResultTime(info, currentTime) val maybeAccumulators = info.accumulables - val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"} + val accumulatorsReadable = maybeAccumulators.map { acc => + StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}") + } val maybeInput = metrics.flatMap(_.inputMetrics) val inputSortable = maybeInput.map(_.bytesRead.toString).getOrElse("") From 20bb10f8644a92a57496b5df639008832b30e34d Mon Sep 17 00:00:00 2001 From: MechCoder Date: Wed, 15 Jul 2015 08:25:53 -0700 Subject: [PATCH 22/46] [SPARK-8706] [PYSPARK] [PROJECT INFRA] Add pylint checks to PySpark This adds Pylint checks to PySpark. For now this lazy installs using easy_install to /dev/pylint (similar to the pep8 script). We still need to figure out what rules to be allowed. 
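For reference, a minimal sketch of how the new check can be exercised locally (assuming a local pylint install; the module path below is only an example, while the script and rcfile names are the ones touched by this patch):

```
# Run the combined Python checks (compileall, pep8, pylint) from the repo root
./dev/lint-python

# Or point pylint at a single module by hand, reusing the project configuration
pylint --rcfile=pylintrc python/pyspark/rdd.py
```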
Author: MechCoder Closes #7241 from MechCoder/pylint and squashes the following commits: 2fc7291 [MechCoder] Remove pylint test fail 6d883a2 [MechCoder] Silence warnings and make pylint tests fail to check if it works in jenkins f3a5e17 [MechCoder] undefined-variable ca8b749 [MechCoder] Minor changes 71629f8 [MechCoder] remove trailing whitespace 8498ff9 [MechCoder] Remove blacklisted arguments and pointless statements check 1dbd094 [MechCoder] Disable all checks for now 8b8aa8a [MechCoder] Add pylint configuration file 7871bb1 [MechCoder] [SPARK-8706] [PySpark] [Project infra] Add pylint checks to PySpark --- dev/lint-python | 57 ++++- pylintrc | 404 ++++++++++++++++++++++++++++++ python/pyspark/ml/param/shared.py | 4 +- python/pyspark/tests.py | 3 +- 4 files changed, 457 insertions(+), 11 deletions(-) create mode 100644 pylintrc diff --git a/dev/lint-python b/dev/lint-python index 0c3586462cb37..e02dff220eb87 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -21,12 +21,14 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/ ./dev/sparktestsupport" PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py" -PYTHON_LINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/python-lint-report.txt" +PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" +PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" +PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" cd "$SPARK_ROOT_DIR" # compileall: https://docs.python.org/2/library/compileall.html -python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYTHON_LINT_REPORT_PATH" +python -B -m compileall -q -l $PATHS_TO_CHECK > "$PEP8_REPORT_PATH" compile_status="${PIPESTATUS[0]}" # Get pep8 at runtime so that we don't rely on it being installed on the build server. @@ -47,11 +49,36 @@ if [ ! -e "$PEP8_SCRIPT_PATH" ]; then fi fi +# Easy install pylint in /dev/pylint. To easy_install into a directory, the PYTHONPATH should +# be set to the directory. +# dev/pylint should be appended to the PATH variable as well. +# Jenkins by default installs the pylint3 version, so for now this just checks the code quality +# of python3. +export "PYTHONPATH=$SPARK_ROOT_DIR/dev/pylint" +export "PYLINT_HOME=$PYTHONPATH" +export "PATH=$PYTHONPATH:$PATH" + +if [ ! -d "$PYLINT_HOME" ]; then + mkdir "$PYLINT_HOME" + # Redirect the annoying pylint installation output. + easy_install -d "$PYLINT_HOME" pylint==1.4.4 &>> "$PYLINT_INSTALL_INFO" + easy_install_status="$?" + + if [ "$easy_install_status" -ne 0 ]; then + echo "Unable to install pylint locally in \"$PYTHONPATH\"." + cat "$PYLINT_INSTALL_INFO" + exit "$easy_install_status" + fi + + rm "$PYLINT_INSTALL_INFO" + +fi + # There is no need to write this output to a file #+ first, but we do so so that the check status can #+ be output before the report, like with the #+ scalastyle and RAT checks. -python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PYTHON_LINT_REPORT_PATH" +python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PEP8_REPORT_PATH" pep8_status="${PIPESTATUS[0]}" if [ "$compile_status" -eq 0 -a "$pep8_status" -eq 0 ]; then @@ -61,13 +88,27 @@ else fi if [ "$lint_status" -ne 0 ]; then - echo "Python lint checks failed." - cat "$PYTHON_LINT_REPORT_PATH" + echo "PEP8 checks failed." + cat "$PEP8_REPORT_PATH" +else + echo "PEP8 checks passed." 
+fi + +rm "$PEP8_REPORT_PATH" + +for to_be_checked in "$PATHS_TO_CHECK" +do + pylint --rcfile="$SPARK_ROOT_DIR/pylintrc" $to_be_checked >> "$PYLINT_REPORT_PATH" +done + +if [ "${PIPESTATUS[0]}" -ne 0 ]; then + lint_status=1 + echo "Pylint checks failed." + cat "$PYLINT_REPORT_PATH" else - echo "Python lint checks passed." + echo "Pylint checks passed." fi -# rm "$PEP8_SCRIPT_PATH" -rm "$PYTHON_LINT_REPORT_PATH" +rm "$PYLINT_REPORT_PATH" exit "$lint_status" diff --git a/pylintrc b/pylintrc new file mode 100644 index 0000000000000..061775960393b --- /dev/null +++ b/pylintrc @@ -0,0 +1,404 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +[MASTER] + +# Specify a configuration file. +#rcfile= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Profiled execution. +profile=no + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=pyspark.heapq3 + +# Pickle collected data for later comparisons. +persistent=yes + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=1 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code +extension-pkg-whitelist= + +# Allow optimization of some AST trees. This will activate a peephole AST +# optimizer, which will apply various small optimizations. For instance, it can +# be used to obtain the result of joining multiple strings with the addition +# operator. Joining a lot of strings can lead to a maximum recursion error in +# Pylint and this flag can prevent that. It has one side effect, the resulting +# AST will be different than the one from reality. +optimize-ast=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time. See also the "--disable" option for examples. +enable= + +# Disable the message, report, category or checker with the given id(s). 
You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" + +# These errors are arranged in order of number of warning given in pylint. +# If you would like to improve the code quality of pyspark, remove any of these disabled errors +# run ./dev/lint-python and see if the errors raised by pylint can be fixed. + +disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Put messages in a separate file for each module / package specified on the +# command line instead of printing them on stdout. Reports (if any) will be +# written in a file name "pylint_global.[txt|html]". +files-output=no + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Add a comment according to your evaluation note. This is used by the global +# evaluation report (RP0004). +comment=no + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. 
+notes=FIXME,XXX,TODO + + +[BASIC] + +# Required attributes for module, separated by a comma +required-attributes= + +# List of builtins function names that should not be used, separated by a comma +bad-functions= + +# Good variable names which should always be accepted, separated by a comma +good-names=i,j,k,ex,Run,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names=baz,toto,tutu,tata + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# Regular expression matching correct function names +function-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for function names +function-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct variable names +variable-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for variable names +variable-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct constant names +const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Naming hint for constant names +const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Regular expression matching correct attribute names +attr-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for attribute names +attr-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct argument names +argument-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for argument names +argument-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ + +# Naming hint for class attribute names +class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ + +# Naming hint for inline iteration names +inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=[A-Z_][a-zA-Z0-9]+$ + +# Naming hint for class names +class-name-hint=[A-Z_][a-zA-Z0-9]+$ + +# Regular expression matching correct module names +module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Naming hint for module names +module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Regular expression matching correct method names +method-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for method names +method-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=__.*__ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=100 + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + +# List of optional constructs for which whitespace checking is disabled +no-space-check=trailing-comma,dict-separator + +# Maximum number of lines in a module +max-module-lines=1000 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 
+expected-line-ending-format= + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=_$|dummy + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging + + +[TYPECHECK] + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis +ignored-modules= + +# List of classes names for which member attributes should not be checked +# (useful for classes with attributes dynamically set). +ignored-classes=SQLObject + +# When zope mode is activated, add a predefined set of Zope acquired attributes +# to generated-members. +zope=no + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E0201 when accessed. Python regular +# expressions are accepted. +generated-members=REQUEST,acl_users,aq_parent + + +[CLASSES] + +# List of interface methods to ignore, separated by a comma. This is used for +# instance to not check methods defines in Zope's Interface base class. +ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__,__new__,setUp + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + +# List of member names, which should be excluded from the protected access +# warning. 
+exclude-protected=_asdict,_fields,_replace,_source,_make + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub,TERMIOS,Bastion,rexec + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + + +[DESIGN] + +# Maximum number of arguments for function / method +max-args=5 + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore +ignored-argument-names=_.* + +# Maximum number of locals for function / method body +max-locals=15 + +# Maximum number of return / yield for function / method body +max-returns=6 + +# Maximum number of branch for function / method body +max-branches=12 + +# Maximum number of statements in function / method body +max-statements=50 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception" +overgeneral-exceptions=Exception diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index bc088e4c29e26..595124726366d 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -444,7 +444,7 @@ class DecisionTreeParams(Params): minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") - + def __init__(self): super(DecisionTreeParams, self).__init__() @@ -460,7 +460,7 @@ def __init__(self): self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") - + def setMaxDepth(self, value): """ Sets the value of :py:attr:`maxDepth`. 
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c5c0add49d02c..21225016805bc 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -893,7 +893,8 @@ def test_pipe_functions(self): self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect) result = rdd.pipe('cat').collect() result.sort() - [self.assertEqual(x, y) for x, y in zip(data, result)] + for x, y in zip(data, result): + self.assertEqual(x, y) self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) self.assertEqual([], rdd.pipe('grep 4').collect()) From 6f6902597d5d687049c103bc0cf6da30919b92d8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 15 Jul 2015 09:48:33 -0700 Subject: [PATCH 23/46] [SPARK-8840] [SPARKR] Add float coercion on SparkR JIRA: https://issues.apache.org/jira/browse/SPARK-8840 Currently the type coercion rules don't include float type. This PR simply adds it. Author: Liang-Chi Hsieh Closes #7280 from viirya/add_r_float_coercion and squashes the following commits: c86dc0e [Liang-Chi Hsieh] For comments. dbf0c1b [Liang-Chi Hsieh] Implicitly convert Double to Float based on provided schema. 733015a [Liang-Chi Hsieh] Add test case for DataFrame with float type. 30c2a40 [Liang-Chi Hsieh] Update test case. 52b5294 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into add_r_float_coercion 6f9159d [Liang-Chi Hsieh] Add another test case. 8db3244 [Liang-Chi Hsieh] schema also needs to support float. add test case. 0dcc992 [Liang-Chi Hsieh] Add float coercion on SparkR. --- R/pkg/R/deserialize.R | 1 + R/pkg/R/schema.R | 1 + R/pkg/inst/tests/test_sparkSQL.R | 26 +++++++++++++++++++ .../scala/org/apache/spark/api/r/SerDe.scala | 4 +++ .../org/apache/spark/sql/api/r/SQLUtils.scala | 15 ++++++++--- 5 files changed, 44 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index d961bbc383688..7d1f6b0819ed0 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -23,6 +23,7 @@ # Int -> integer # String -> character # Boolean -> logical +# Float -> double # Double -> double # Long -> double # Array[Byte] -> raw diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 15e2bdbd55d79..06df430687682 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -123,6 +123,7 @@ structField.character <- function(x, type, nullable = TRUE) { } options <- c("byte", "integer", + "float", "double", "numeric", "character", diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index b0ea38854304e..76f74f80834a9 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -108,6 +108,32 @@ test_that("create DataFrame from RDD", { expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- jsonFile(sqlContext, jsonPathNa) + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") + insertInto(df, "people") + expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16)) + expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5)) + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + df2 <- createDataFrame(sqlContext, df.toRDD, schema) + expect_equal(columns(df2), c("name", "age", 
"height")) + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5)) + + localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18), height=c(164.10, 181.4, 173.7)) + df <- createDataFrame(sqlContext, localDF, schema) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + expect_equal(columns(df), c("name", "age", "height")) + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10)) }) test_that("convert NAs to null type in DataFrames", { diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 56adc857d4ce0..d5b4260bf4529 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -179,6 +179,7 @@ private[spark] object SerDe { // Int -> integer // String -> character // Boolean -> logical + // Float -> double // Double -> double // Long -> double // Array[Byte] -> raw @@ -215,6 +216,9 @@ private[spark] object SerDe { case "long" | "java.lang.Long" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Long].toDouble) + case "float" | "java.lang.Float" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Float].toDouble) case "double" | "java.lang.Double" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Double]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 43b62f0e822f8..92861ab038f19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -47,6 +47,7 @@ private[r] object SQLUtils { dataType match { case "byte" => org.apache.spark.sql.types.ByteType case "integer" => org.apache.spark.sql.types.IntegerType + case "float" => org.apache.spark.sql.types.FloatType case "double" => org.apache.spark.sql.types.DoubleType case "numeric" => org.apache.spark.sql.types.DoubleType case "character" => org.apache.spark.sql.types.StringType @@ -68,7 +69,7 @@ private[r] object SQLUtils { def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = { val num = schema.fields.size - val rowRDD = rdd.map(bytesToRow) + val rowRDD = rdd.map(bytesToRow(_, schema)) sqlContext.createDataFrame(rowRDD, schema) } @@ -76,12 +77,20 @@ private[r] object SQLUtils { df.map(r => rowToRBytes(r)) } - private[this] def bytesToRow(bytes: Array[Byte]): Row = { + private[this] def doConversion(data: Object, dataType: DataType): Object = { + data match { + case d: java.lang.Double if dataType == FloatType => + new java.lang.Float(d) + case _ => data + } + } + + private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = { val bis = new ByteArrayInputStream(bytes) val dis = new DataInputStream(bis) val num = SerDe.readInt(dis) Row.fromSeq((0 until num).map { i => - SerDe.readObject(dis) + doConversion(SerDe.readObject(dis), schema.fields(i).dataType) }.toSeq) } From fa4ec3606a965238423f977808163983c9d56e0a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 15 Jul 2015 10:31:39 -0700 Subject: [PATCH 24/46] [SPARK-9020][SQL] Support mutable state in code gen expressions We can keep expressions' mutable states in generated class(like `SpecificProjection`) as 
member variables, so that we can read and modify them inside codegened expressions. Author: Wenchen Fan Closes #7392 from cloud-fan/mutable-state and squashes the following commits: eb3a221 [Wenchen Fan] fix order 73144d8 [Wenchen Fan] naming improvement 318f41d [Wenchen Fan] address more comments d43b65d [Wenchen Fan] address comments fd45c7a [Wenchen Fan] Support mutable state in code gen expressions --- .../scala/org/apache/spark/TaskContext.scala | 15 ++- .../expressions/codegen/CodeGenerator.scala | 17 +++- .../codegen/GenerateMutableProjection.scala | 4 + .../codegen/GenerateOrdering.scala | 38 +++++-- .../codegen/GeneratePredicate.scala | 4 + .../codegen/GenerateProjection.scala | 99 ++++++++++--------- .../sql/catalyst/expressions/random.scala | 29 +++++- .../MonotonicallyIncreasingID.scala | 19 +++- .../expressions/SparkPartitionID.scala | 12 ++- 9 files changed, 171 insertions(+), 66 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index d09e17dea0911..248339148d9b7 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -32,7 +32,20 @@ object TaskContext { */ def get(): TaskContext = taskContext.get - private val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] + /** + * Returns the partition id of currently active TaskContext. It will return 0 + * if there is no active TaskContext for cases like local execution. + */ + def getPartitionId(): Int = { + val tc = taskContext.get() + if (tc == null) { + 0 + } else { + tc.partitionId() + } + } + + private[this] val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] // Note: protected[spark] instead of private[spark] to prevent the following two from // showing up in JavaDoc. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9f6329bbda4ec..328d635de8743 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -56,6 +56,18 @@ class CodeGenContext { */ val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() + /** + * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a + * 3-tuple: java type, variable name, code to init it. + * They will be kept as member variables in generated classes like `SpecificProjection`. + */ + val mutableStates: mutable.ArrayBuffer[(String, String, String)] = + mutable.ArrayBuffer.empty[(String, String, String)] + + def addMutableState(javaType: String, variableName: String, initialValue: String): Unit = { + mutableStates += ((javaType, variableName, initialValue)) + } + val stringType: String = classOf[UTF8String].getName val decimalType: String = classOf[Decimal].getName @@ -203,7 +215,10 @@ class CodeGenContext { def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) } - +/** + * A wrapper for generated class, defines a `generate` method so that we can pass extra objects + * into generated class. 
+ */ abstract class GeneratedClass { def generate(expressions: Array[Expression]): Any } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index addb8023d9c0b..71e47d4f9b620 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -46,6 +46,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; """ }.mkString("\n") + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" + }.mkString("\n ") val code = s""" public Object generate($exprType[] expr) { return new SpecificProjection(expr); @@ -55,6 +58,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu private $exprType[] expressions = null; private $mutableRowType mutableRow = null; + $mutableStates public SpecificProjection($exprType[] expr) { expressions = expr; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index d05dfc108e63a..856ff9f1f96f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -46,30 +46,47 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = { val ctx = newCodeGenContext() - val comparisons = ordering.zipWithIndex.map { case (order, i) => - val evalA = order.child.gen(ctx) - val evalB = order.child.gen(ctx) + val comparisons = ordering.map { order => + val eval = order.child.gen(ctx) val asc = order.direction == Ascending + val isNullA = ctx.freshName("isNullA") + val primitiveA = ctx.freshName("primitiveA") + val isNullB = ctx.freshName("isNullB") + val primitiveB = ctx.freshName("primitiveB") s""" i = a; - ${evalA.code} + boolean $isNullA; + ${ctx.javaType(order.child.dataType)} $primitiveA; + { + ${eval.code} + $isNullA = ${eval.isNull}; + $primitiveA = ${eval.primitive}; + } i = b; - ${evalB.code} - if (${evalA.isNull} && ${evalB.isNull}) { + boolean $isNullB; + ${ctx.javaType(order.child.dataType)} $primitiveB; + { + ${eval.code} + $isNullB = ${eval.isNull}; + $primitiveB = ${eval.primitive}; + } + if ($isNullA && $isNullB) { // Nothing - } else if (${evalA.isNull}) { + } else if ($isNullA) { return ${if (order.direction == Ascending) "-1" else "1"}; - } else if (${evalB.isNull}) { + } else if ($isNullB) { return ${if (order.direction == Ascending) "1" else "-1"}; } else { - int comp = ${ctx.genComp(order.child.dataType, evalA.primitive, evalB.primitive)}; + int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)}; if (comp != 0) { return ${if (asc) "comp" else "-comp"}; } } """ }.mkString("\n") - + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" + }.mkString("\n ") val code = s""" 
public SpecificOrdering generate($exprType[] expr) { return new SpecificOrdering(expr); @@ -78,6 +95,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR class SpecificOrdering extends ${classOf[BaseOrdering].getName} { private $exprType[] expressions = null; + $mutableStates public SpecificOrdering($exprType[] expr) { expressions = expr; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 274a42cb69087..9e5a745d512e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -40,6 +40,9 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool protected def create(predicate: Expression): ((InternalRow) => Boolean) = { val ctx = newCodeGenContext() val eval = predicate.gen(ctx) + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" + }.mkString("\n ") val code = s""" public SpecificPredicate generate($exprType[] expr) { return new SpecificPredicate(expr); @@ -47,6 +50,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool class SpecificPredicate extends ${classOf[Predicate].getName} { private final $exprType[] expressions; + $mutableStates public SpecificPredicate($exprType[] expr) { expressions = expr; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 3c7ee9cc16599..3e5ca308dc31d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -151,6 +151,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { s"""if (!nullBits[$i]) arr[$i] = c$i;""" }.mkString("\n ") + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" + }.mkString("\n ") + val code = s""" public SpecificProjection generate($exprType[] expr) { return new SpecificProjection(expr); @@ -158,6 +162,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { class SpecificProjection extends ${classOf[BaseProject].getName} { private $exprType[] expressions = null; + $mutableStates public SpecificProjection($exprType[] expr) { expressions = expr; @@ -165,65 +170,65 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { @Override public Object apply(Object r) { - return new SpecificRow(expressions, (InternalRow) r); + return new SpecificRow((InternalRow) r); } - } - final class SpecificRow extends ${classOf[MutableRow].getName} { + final class SpecificRow extends ${classOf[MutableRow].getName} { - $columns + $columns - public SpecificRow($exprType[] expressions, InternalRow i) { - $initColumns - } + public SpecificRow(InternalRow i) { + $initColumns + } - public int length() { return ${expressions.length};} - protected boolean[] nullBits = new boolean[${expressions.length}]; - public void setNullAt(int i) { 
nullBits[i] = true; } - public boolean isNullAt(int i) { return nullBits[i]; } + public int length() { return ${expressions.length};} + protected boolean[] nullBits = new boolean[${expressions.length}]; + public void setNullAt(int i) { nullBits[i] = true; } + public boolean isNullAt(int i) { return nullBits[i]; } - public Object get(int i) { - if (isNullAt(i)) return null; - switch (i) { - $getCases + public Object get(int i) { + if (isNullAt(i)) return null; + switch (i) { + $getCases + } + return null; } - return null; - } - public void update(int i, Object value) { - if (value == null) { - setNullAt(i); - return; + public void update(int i, Object value) { + if (value == null) { + setNullAt(i); + return; + } + nullBits[i] = false; + switch (i) { + $updateCases + } } - nullBits[i] = false; - switch (i) { - $updateCases + $specificAccessorFunctions + $specificMutatorFunctions + + @Override + public int hashCode() { + int result = 37; + $hashUpdates + return result; } - } - $specificAccessorFunctions - $specificMutatorFunctions - - @Override - public int hashCode() { - int result = 37; - $hashUpdates - return result; - } - @Override - public boolean equals(Object other) { - if (other instanceof SpecificRow) { - SpecificRow row = (SpecificRow) other; - $columnChecks - return true; + @Override + public boolean equals(Object other) { + if (other instanceof SpecificRow) { + SpecificRow row = (SpecificRow) other; + $columnChecks + return true; + } + return super.equals(other); } - return super.equals(other); - } - @Override - public InternalRow copy() { - Object[] arr = new Object[${expressions.length}]; - ${copyColumns} - return new ${classOf[GenericInternalRow].getName}(arr); + @Override + public InternalRow copy() { + Object[] arr = new Object[${expressions.length}]; + ${copyColumns} + return new ${classOf[GenericInternalRow].getName}(arr); + } } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index 6cdc3000382e2..e10ba55396664 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types.{DataType, DoubleType} import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -38,11 +39,7 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { * Record ID within each partition. By being transient, the Random Number Generator is * reset every time we serialize and deserialize it. 
*/ - @transient protected lazy val partitionId = TaskContext.get() match { - case null => 0 - case _ => TaskContext.get().partitionId() - } - @transient protected lazy val rng = new XORShiftRandom(seed + partitionId) + @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.getPartitionId) override def deterministic: Boolean = false @@ -61,6 +58,17 @@ case class Rand(seed: Long) extends RDG(seed) { case IntegerLiteral(s) => s case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rngTerm = ctx.freshName("rng") + val className = classOf[XORShiftRandom].getCanonicalName + ctx.addMutableState(className, rngTerm, + s"new $className($seed + org.apache.spark.TaskContext.getPartitionId())") + ev.isNull = "false" + s""" + final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextDouble(); + """ + } } /** Generate a random column with i.i.d. gaussian random distribution. */ @@ -73,4 +81,15 @@ case class Randn(seed: Long) extends RDG(seed) { case IntegerLiteral(s) => s case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rngTerm = ctx.freshName("rng") + val className = classOf[XORShiftRandom].getCanonicalName + ctx.addMutableState(className, rngTerm, + s"new $className($seed + org.apache.spark.TaskContext.getPartitionId())") + ev.isNull = "false" + s""" + final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextGaussian(); + """ + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 437d143e53f3f..69a37750d7525 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.LeafExpression +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{LongType, DataType} /** @@ -40,6 +41,8 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { */ @transient private[this] var count: Long = 0L + @transient private lazy val partitionMask = TaskContext.getPartitionId.toLong << 33 + override def nullable: Boolean = false override def dataType: DataType = LongType @@ -47,6 +50,20 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { override def eval(input: InternalRow): Long = { val currentCount = count count += 1 - (TaskContext.get().partitionId().toLong << 33) + currentCount + partitionMask + currentCount + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val countTerm = ctx.freshName("count") + val partitionMaskTerm = ctx.freshName("partitionMask") + ctx.addMutableState(ctx.JAVA_LONG, countTerm, "0L") + ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, + "((long) org.apache.spark.TaskContext.getPartitionId()) << 33") + + ev.isNull = "false" + s""" + final ${ctx.javaType(dataType)} ${ev.primitive} = $partitionMaskTerm + 
$countTerm; + $countTerm++; + """ } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index 822d3d8c9108d..5f1b514f2cff2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.LeafExpression +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{IntegerType, DataType} @@ -32,5 +33,14 @@ private[sql] case object SparkPartitionID extends LeafExpression { override def dataType: DataType = IntegerType - override def eval(input: InternalRow): Int = TaskContext.get().partitionId() + @transient private lazy val partitionId = TaskContext.getPartitionId + + override def eval(input: InternalRow): Int = partitionId + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val idTerm = ctx.freshName("partitionId") + ctx.addMutableState(ctx.JAVA_INT, idTerm, "org.apache.spark.TaskContext.getPartitionId()") + ev.isNull = "false" + s"final ${ctx.javaType(dataType)} ${ev.primitive} = $idTerm;" + } } From a9385271a9f6b97ec6aa619cf56ee556ba2fb0de Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Wed, 15 Jul 2015 10:43:38 -0700 Subject: [PATCH 25/46] [SPARK-8221][SQL]Add pmod function https://issues.apache.org/jira/browse/SPARK-8221 One concern is the result would be negative if the divisor is not positive( i.e pmod(7, -3) ), but the behavior is the same as hive. 
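For reference, a minimal standalone sketch (plain Scala, mirroring the `pmod` helper this patch adds to `arithmetic.scala`) of how the positive-modulus rule behaves around negative operands; the sample values are illustrative only:

    // r = a % n keeps the sign of the dividend in Scala/Java, so a negative
    // remainder is folded back into range by adding the divisor and taking % n again.
    def pmod(a: Int, n: Int): Int = {
      val r = a % n
      if (r < 0) (r + n) % n else r
    }

    pmod(7, 3)    // 1
    pmod(-7, 3)   // 2   (plain -7 % 3 would give -1)
    pmod(-7, -3)  // -1  (still negative for a negative divisor, the case the note above is concerned with)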
Author: zhichao.li Closes #6783 from zhichao-li/pmod2 and squashes the following commits: 7083eb9 [zhichao.li] update to the latest type checking d26dba7 [zhichao.li] add pmod --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../catalyst/analysis/HiveTypeCoercion.scala | 6 ++ .../sql/catalyst/expressions/arithmetic.scala | 94 +++++++++++++++++++ .../ArithmeticExpressionSuite.scala | 16 +++- .../org/apache/spark/sql/functions.scala | 17 ++++ .../spark/sql/DataFrameFunctionsSuite.scala | 37 ++++++++ 6 files changed, 170 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ec75f51d5e4ff..d2678ce860701 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -115,6 +115,7 @@ object FunctionRegistry { expression[Log2]("log2"), expression[Pow]("pow"), expression[Pow]("power"), + expression[Pmod]("pmod"), expression[UnaryPositive]("positive"), expression[Rint]("rint"), expression[Round]("round"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 15da5eecc8d3c..25087915b5c35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -426,6 +426,12 @@ object HiveTypeCoercion { DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) ) + case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + ) + // When we compare 2 decimal types with different precisions, cast them to the smallest // common precision. 
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1a55a0876f303..394ef556e04a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -377,3 +377,97 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "min" override def prettyName: String = symbol } + +case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { + + override def toString: String = s"pmod($left, $right)" + + override def symbol: String = "pmod" + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "pmod") + + override def inputType: AbstractDataType = NumericType + + protected override def nullSafeEval(left: Any, right: Any) = + dataType match { + case IntegerType => pmod(left.asInstanceOf[Int], right.asInstanceOf[Int]) + case LongType => pmod(left.asInstanceOf[Long], right.asInstanceOf[Long]) + case ShortType => pmod(left.asInstanceOf[Short], right.asInstanceOf[Short]) + case ByteType => pmod(left.asInstanceOf[Byte], right.asInstanceOf[Byte]) + case FloatType => pmod(left.asInstanceOf[Float], right.asInstanceOf[Float]) + case DoubleType => pmod(left.asInstanceOf[Double], right.asInstanceOf[Double]) + case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + dataType match { + case dt: DecimalType => + val decimalAdd = "$plus" + s""" + ${ctx.javaType(dataType)} r = $eval1.remainder($eval2); + if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { + ${ev.primitive} = (r.$decimalAdd($eval2)).remainder($eval2); + } else { + ${ev.primitive} = r; + } + """ + // byte and short are casted into int when add, minus, times or divide + case ByteType | ShortType => + s""" + ${ctx.javaType(dataType)} r = (${ctx.javaType(dataType)})($eval1 % $eval2); + if (r < 0) { + ${ev.primitive} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2); + } else { + ${ev.primitive} = r; + } + """ + case _ => + s""" + ${ctx.javaType(dataType)} r = $eval1 % $eval2; + if (r < 0) { + ${ev.primitive} = (r + $eval2) % $eval2; + } else { + ${ev.primitive} = r; + } + """ + } + }) + } + + private def pmod(a: Int, n: Int): Int = { + val r = a % n + if (r < 0) {(r + n) % n} else r + } + + private def pmod(a: Long, n: Long): Long = { + val r = a % n + if (r < 0) {(r + n) % n} else r + } + + private def pmod(a: Byte, n: Byte): Byte = { + val r = a % n + if (r < 0) {((r + n) % n).toByte} else r.toByte + } + + private def pmod(a: Double, n: Double): Double = { + val r = a % n + if (r < 0) {(r + n) % n} else r + } + + private def pmod(a: Short, n: Short): Short = { + val r = a % n + if (r < 0) {((r + n) % n).toShort} else r.toShort + } + + private def pmod(a: Float, n: Float): Float = { + val r = a % n + if (r < 0) {(r + n) % n} else r + } + + private def pmod(a: Decimal, n: Decimal): Decimal = { + val r = a % n + if (r.compare(Decimal(0)) < 0) {(r + n) % n} else r + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 6c93698f8017b..e7e5231d32c9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -21,7 +21,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types.Decimal - class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { /** @@ -158,4 +157,19 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)), Array(1.toByte, 2.toByte)) } + + test("pmod") { + testNumericDataTypes { convert => + val left = Literal(convert(7)) + val right = Literal(convert(3)) + checkEvaluation(Pmod(left, right), convert(1)) + checkEvaluation(Pmod(Literal.create(null, left.dataType), right), null) + checkEvaluation(Pmod(left, Literal.create(null, right.dataType)), null) + checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0 + } + checkEvaluation(Pmod(-7, 3), 2) + checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005) + checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1)) + checkEvaluation(Pmod(2L, Long.MaxValue), 2) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5119ee31d852d..c7deaca8437a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1371,6 +1371,23 @@ object functions { */ def pow(l: Double, rightName: String): Column = pow(l, Column(rightName)) + /** + * Returns the positive value of dividend mod divisor. + * + * @group math_funcs + * @since 1.5.0 + */ + def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr) + + /** + * Returns the positive value of dividend mod divisor. + * + * @group math_funcs + * @since 1.5.0 + */ + def pmod(dividendColName: String, divisorColName: String): Column = + pmod(Column(dividendColName), Column(divisorColName)) + /** * Returns the double value that is closest in value to the argument and * is equal to a mathematical integer. 
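As a quick usage sketch of the DataFrame-side API added above (the DataFrame `df` and its integer columns `a` and `b` are hypothetical; the calls themselves are the ones introduced in `functions.scala` and registered in `FunctionRegistry`):

    import org.apache.spark.sql.functions.{lit, pmod}

    // column-to-column and column-to-literal forms
    df.select(pmod(df("a"), df("b")))
    df.select(pmod(df("a"), lit(3)))

    // via the SQL registry entry, usable from selectExpr and SQL strings
    df.selectExpr("pmod(a, b)")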
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6cebec95d2850..70bd78737f69c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -403,4 +403,41 @@ class DataFrameFunctionsSuite extends QueryTest { Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3)) ) } + + test("pmod") { + val intData = Seq((7, 3), (-7, 3)).toDF("a", "b") + checkAnswer( + intData.select(pmod('a, 'b)), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.select(pmod('a, lit(3))), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.select(pmod(lit(-7), 'b)), + Seq(Row(2), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(a, b)"), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(a, 3)"), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(-7, b)"), + Seq(Row(2), Row(2)) + ) + val doubleData = Seq((7.2, 4.1)).toDF("a", "b") + checkAnswer( + doubleData.select(pmod('a, 'b)), + Seq(Row(3.1000000000000005)) // same as hive + ) + checkAnswer( + doubleData.select(pmod(lit(2), lit(Int.MaxValue))), + Seq(Row(2)) + ) + } } From 9716a727fb2d11380794549039e12e53c771e120 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 15 Jul 2015 10:46:22 -0700 Subject: [PATCH 26/46] [Minor][SQL] Allow spaces in the beginning and ending of string for Interval This is a minor fixing for #7355 to allow spaces in the beginning and ending of string parsed to `Interval`. Author: Liang-Chi Hsieh Closes #7390 from viirya/fix_interval_string and squashes the following commits: 9eb6831 [Liang-Chi Hsieh] Use trim instead of modifying regex. 57861f7 [Liang-Chi Hsieh] Fix scala style. 815a9cb [Liang-Chi Hsieh] Slightly modify regex to allow spaces in the beginning and ending of string. 
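To make the intent concrete, a small sketch of the behavior this change enables, with the expected value taken from the test below (shown as Scala calling the Java class, for consistency with the rest of this series):

    import org.apache.spark.unsafe.types.Interval

    val expected = new Interval(-5 * 12 + 23, 0)
    Interval.fromString("interval -5 years 23 month")       // parses to `expected`
    Interval.fromString("  interval -5 years 23 month  ")   // previously null; the input is now trimmed first and parses to `expected`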
--- .../main/java/org/apache/spark/unsafe/types/Interval.java | 1 + .../java/org/apache/spark/unsafe/types/IntervalSuite.java | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java index eb7475e9df869..905ea0b7b878c 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java @@ -62,6 +62,7 @@ public static Interval fromString(String s) { if (s == null) { return null; } + s = s.trim(); Matcher m = p.matcher(s); if (!m.matches() || s.equals("interval")) { return null; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java index 44a949a371f2b..1832d0bc65551 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java @@ -75,6 +75,12 @@ public void fromStringTest() { Interval result = new Interval(-5 * 12 + 23, 0); assertEquals(Interval.fromString(input), result); + input = "interval -5 years 23 month "; + assertEquals(Interval.fromString(input), result); + + input = " interval -5 years 23 month "; + assertEquals(Interval.fromString(input), result); + // Error cases input = "interval 3month 1 hour"; assertEquals(Interval.fromString(input), null); From 303c1201c468d360a5f600ce37b8bee75a77a0e6 Mon Sep 17 00:00:00 2001 From: Shuo Xiang Date: Wed, 15 Jul 2015 12:10:53 -0700 Subject: [PATCH 27/46] [SPARK-7555] [DOCS] Add doc for elastic net in ml-guide and mllib-guide jkbradley I put the elastic net under the **Algorithm guide** section. Also add the formula of elastic net in mllib-linear `mllib-linear-methods#regularizers`. dbtsai I left the code tab for you to add example code. Do you think it is the right place? Author: Shuo Xiang Closes #6504 from coderxiang/elasticnet and squashes the following commits: f6061ee [Shuo Xiang] typo 90a7c88 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' into elasticnet 0610a36 [Shuo Xiang] move out the elastic net to ml-linear-methods 8747190 [Shuo Xiang] merge master 706d3f7 [Shuo Xiang] add python code 9bc2b4c [Shuo Xiang] typo db32a60 [Shuo Xiang] java code sample aab3b3a [Shuo Xiang] Merge remote-tracking branch 'upstream/master' into elasticnet a0dae07 [Shuo Xiang] simplify code d8616fd [Shuo Xiang] Update the definition of elastic net. 
Add scala code; Mention Lasso and Ridge df5bd14 [Shuo Xiang] use wikipeida page in ml-linear-methods.md 78d9366 [Shuo Xiang] address comments 8ce37c2 [Shuo Xiang] Merge branch 'elasticnet' of github.com:coderxiang/spark into elasticnet 8f24848 [Shuo Xiang] Merge branch 'elastic-net-doc' of github.com:coderxiang/spark into elastic-net-doc 998d766 [Shuo Xiang] Merge branch 'elastic-net-doc' of github.com:coderxiang/spark into elastic-net-doc 89f10e4 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' into elastic-net-doc 9262a72 [Shuo Xiang] update 7e07d12 [Shuo Xiang] update b32f21a [Shuo Xiang] add doc for elastic net in sparkml 937eef1 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' into elastic-net-doc 180b496 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' aa0717d [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 5f109b4 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' c5c5bfe [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 98804c9 [Shuo Xiang] fix bug in topBykey and update test --- docs/ml-guide.md | 31 +++++++++ docs/ml-linear-methods.md | 129 +++++++++++++++++++++++++++++++++++ docs/mllib-linear-methods.md | 53 +++++++------- 3 files changed, 188 insertions(+), 25 deletions(-) create mode 100644 docs/ml-linear-methods.md diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c74cb1f1ef8ea..8c46adf256a9a 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -3,6 +3,24 @@ layout: global title: Spark ML Programming Guide --- +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of high-level APIs that help users create and tune practical machine learning pipelines. @@ -154,6 +172,19 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. +# Algorithm Guides + +There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. + +**Pipelines API Algorithm Guides** + +* [Feature Extraction, Transformation, and Selection](ml-features.html) +* [Ensembles](ml-ensembles.html) + +**Algorithms in `spark.ml`** + +* [Linear methods with elastic net regularization](ml-linear-methods.html) + # Code Examples This section gives code examples illustrating the functionality discussed above. 
diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md new file mode 100644 index 0000000000000..1ac83d94c9e81 --- /dev/null +++ b/docs/ml-linear-methods.md @@ -0,0 +1,129 @@ +--- +layout: global +title: Linear Methods - ML +displayTitle: ML - Linear Methods +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + +In MLlib, we implement popular linear methods such as logistic regression and linear least squares with L1 or L2 regularization. Refer to [the linear methods in mllib](mllib-linear-methods.html) for details. In `spark.ml`, we also include Pipelines API for [Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid of L1 and L2 regularization proposed in [this paper](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). Mathematically it is defined as a linear combination of the L1-norm and the L2-norm: +`\[ +\alpha \|\wv\|_1 + (1-\alpha) \frac{1}{2}\|\wv\|_2^2, \alpha \in [0, 1]. +\]` +By setting $\alpha$ properly, it contains both L1 and L2 regularization as special cases. For example, if a [linear regression](https://en.wikipedia.org/wiki/Linear_regression) model is trained with the elastic net parameter $\alpha$ set to $1$, it is equivalent to a [Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. On the other hand, if $\alpha$ is set to $0$, the trained model reduces to a [ridge regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. We implement Pipelines API for both linear regression and logistic regression with elastic net regularization. + +**Examples** + +
+
+<div data-lang="scala" markdown="1">
+ +{% highlight scala %} + +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.mllib.util.MLUtils + +// Load training data +val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + +// Fit the model +val lrModel = lr.fit(training) + +// Print the weights and intercept for logistic regression +println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") + +{% endhighlight %} + +
+
+<div data-lang="java" markdown="1">
+ +{% highlight java %} + +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class LogisticRegressionWithElasticNetExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("Logistic Regression with Elastic Net Example"); + + SparkContext sc = new SparkContext(conf); + SQLContext sql = new SQLContext(sc); + String path = "sample_libsvm_data.txt"; + + // Load training data + DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class); + + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + + // Fit the model + LogisticRegressionModel lrModel = lr.fit(training); + + // Print the weights and intercept for logistic regression + System.out.println("Weights: " + lrModel.weights() + " Intercept: " + lrModel.intercept()); + } +} +{% endhighlight %} +
+
+<div data-lang="python" markdown="1">
+ +{% highlight python %} + +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import MLUtils + +# Load training data +training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + +# Fit the model +lrModel = lr.fit(training) + +# Print the weights and intercept for logistic regression +print("Weights: " + str(lrModel.weights)) +print("Intercept: " + str(lrModel.intercept)) +{% endhighlight %} + +
+
+</div>
+ +### Optimization + +The optimization algorithm underlies the implementation is called [Orthant-Wise Limited-memory QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) +(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 regularization and elastic net. diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 3927d65fbf8fb..07655baa414b5 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -10,7 +10,7 @@ displayTitle: MLlib - Linear Methods `\[ \newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} +\newcommand{\E}{\mathbb{E}} \newcommand{\x}{\mathbf{x}} \newcommand{\y}{\mathbf{y}} \newcommand{\wv}{\mathbf{w}} @@ -18,10 +18,10 @@ displayTitle: MLlib - Linear Methods \newcommand{\bv}{\mathbf{b}} \newcommand{\N}{\mathbb{N}} \newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} \newcommand{\zero}{\mathbf{0}} \]` @@ -29,7 +29,7 @@ displayTitle: MLlib - Linear Methods Many standard *machine learning* methods can be formulated as a convex optimization problem, i.e. the task of finding a minimizer of a convex function `$f$` that depends on a variable vector -`$\wv$` (called `weights` in the code), which has `$d$` entries. +`$\wv$` (called `weights` in the code), which has `$d$` entries. Formally, we can write this as the optimization problem `$\min_{\wv \in\R^d} \; f(\wv)$`, where the objective function is of the form `\begin{equation} @@ -39,7 +39,7 @@ the objective function is of the form \ . \end{equation}` Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and -`$y_i\in\R$` are their corresponding labels, which we want to predict. +`$y_i\in\R$` are their corresponding labels, which we want to predict. We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$. Several of MLlib's classification and regression algorithms fall into this category, and are discussed here. @@ -99,6 +99,9 @@ regularizers in MLlib:
     <tr>
       <td>L1</td><td>$\|\wv\|_1$</td><td>$\mathrm{sign}(\wv)$</td>
     </tr>
+    <tr>
+      <td>elastic net</td><td>$\alpha \|\wv\|_1 + (1-\alpha)\frac{1}{2}\|\wv\|_2^2$</td><td>$\alpha \mathrm{sign}(\wv) + (1-\alpha) \wv$</td>
+    </tr>
@@ -107,7 +110,7 @@ of `$\wv$`. L2-regularized problems are generally easier to solve than L1-regularized due to smoothness. However, L1 regularization can help promote sparsity in weights leading to smaller and more interpretable models, the latter of which can be useful for feature selection. -It is not recommended to train models without any regularization, +[Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization) is a combination of L1 and L2 regularization. It is not recommended to train models without any regularization, especially when the number of training examples is small. ### Optimization @@ -531,7 +534,7 @@ sameModel = LogisticRegressionModel.load(sc, "myModelPath") ### Linear least squares, Lasso, and ridge regression -Linear least squares is the most common formulation for regression problems. +Linear least squares is the most common formulation for regression problems. It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss function in the formulation given by the squared loss: `\[ @@ -539,8 +542,8 @@ L(\wv;\x,y) := \frac{1}{2} (\wv^T \x - y)^2. \]` Various related regression methods are derived by using different types of regularization: -[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or -[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses +[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or +[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses no regularization; [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) uses L2 regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) uses L1 regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is @@ -552,7 +555,7 @@ known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_erro
The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. -The example then uses LinearRegressionWithSGD to build a simple linear model to predict label +The example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). @@ -614,7 +617,7 @@ public class LinearRegression { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); JavaSparkContext sc = new JavaSparkContext(conf); - + // Load and parse the data String path = "data/mllib/ridge-data/lpsa.data"; JavaRDD data = sc.textFile(path); @@ -634,7 +637,7 @@ public class LinearRegression { // Building the model int numIterations = 100; - final LinearRegressionModel model = + final LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); // Evaluate model on training examples and compute training error @@ -665,7 +668,7 @@ public class LinearRegression {
The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. -The example then uses LinearRegressionWithSGD to build a simple linear model to predict label +The example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). @@ -706,8 +709,8 @@ a dependency. ###Streaming linear regression -When data arrive in a streaming fashion, it is useful to fit regression models online, -updating the parameters of the model as new data arrives. MLlib currently supports +When data arrive in a streaming fashion, it is useful to fit regression models online, +updating the parameters of the model as new data arrives. MLlib currently supports streaming linear regression using ordinary least squares. The fitting is similar to that performed offline, except fitting occurs on each batch of data, so that the model continually updates to reflect the data from the stream. @@ -722,7 +725,7 @@ online to the first stream, and make predictions on the second stream.
-First, we import the necessary classes for parsing our input data and creating the model. +First, we import the necessary classes for parsing our input data and creating the model. {% highlight scala %} @@ -734,7 +737,7 @@ import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD Then we make input streams for training and testing data. We assume a StreamingContext `ssc` has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) -for more info. For this example, we use labeled points in training and testing streams, +for more info. For this example, we use labeled points in training and testing streams, but in practice you will likely want to use unlabeled vectors for test data. {% highlight scala %} @@ -754,7 +757,7 @@ val model = new StreamingLinearRegressionWithSGD() {% endhighlight %} -Now we register the streams for training and testing and start the job. +Now we register the streams for training and testing and start the job. Printing predictions alongside true labels lets us easily see the result. {% highlight scala %} @@ -764,14 +767,14 @@ model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() - + {% endhighlight %} We can now save text files with data to the training or testing folders. -Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label -and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. -As you feed more data to the training directory, the predictions +Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label +and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +As you feed more data to the training directory, the predictions will get better!
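For instance (values purely illustrative), a text file dropped into `/training/data/dir` would contain lines such as:

    (1.0,[0.5,0.2,0.9])
    (0.0,[1.3,0.4,0.1])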
From ec9b621647b893abae3afef219bceab382b99564 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Wed, 15 Jul 2015 12:15:35 -0700 Subject: [PATCH 28/46] SPARK-9070 JavaDataFrameSuite teardown NPEs if setup failed fix teardown to skip table delete if hive context is null Author: Steve Loughran Closes #7425 from steveloughran/stevel/patches/SPARK-9070-JavaDataFrameSuite-NPE and squashes the following commits: 1982d38 [Steve Loughran] SPARK-9070 JavaDataFrameSuite teardown NPEs if setup failed --- .../test/org/apache/spark/sql/hive/JavaDataFrameSuite.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java index c4828c4717643..741a3cd31c603 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -61,7 +61,9 @@ public void setUp() throws IOException { @After public void tearDown() throws IOException { // Clean up tables. - hc.sql("DROP TABLE IF EXISTS window_table"); + if (hc != null) { + hc.sql("DROP TABLE IF EXISTS window_table"); + } } @Test From 536533cad83a26f8fa7c60042904a31057ab56c2 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Wed, 15 Jul 2015 13:32:25 -0700 Subject: [PATCH 29/46] [SPARK-9005] [MLLIB] Fix RegressionMetrics computation of explainedVariance Fixes implementation of `explainedVariance` and `r2` to be consistent with their definitions as described in [SPARK-9005](https://issues.apache.org/jira/browse/SPARK-9005). Author: Feynman Liang Closes #7361 from feynmanliang/SPARK-9005-RegressionMetrics-bugs and squashes the following commits: f1112fc [Feynman Liang] Add explainedVariance formula 1a3d098 [Feynman Liang] SROwen code review comments 08a0e1b [Feynman Liang] Fix pyspark tests db8605a [Feynman Liang] Style fix bde9761 [Feynman Liang] Fix RegressionMetrics tests, relax assumption predictor is unbiased c235de0 [Feynman Liang] Fix RegressionMetrics tests 4c4e56f [Feynman Liang] Fix RegressionMetrics computation of explainedVariance and r2 --- .../mllib/evaluation/RegressionMetrics.scala | 27 +++++--- .../evaluation/RegressionMetricsSuite.scala | 69 +++++++++++++++++-- python/pyspark/mllib/evaluation.py | 2 +- 3 files changed, 83 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index e577bf87f885e..408847afa800d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -53,14 +53,22 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend ) summary } + private lazy val SSerr = math.pow(summary.normL2(1), 2) + private lazy val SStot = summary.variance(0) * (summary.count - 1) + private lazy val SSreg = { + val yMean = summary.mean(0) + predictionAndObservations.map { + case (prediction, _) => math.pow(prediction - yMean, 2) + }.sum() + } /** - * Returns the explained variance regression score. - * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) - * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + * Returns the variance explained by regression. 
+ * explainedVariance = \sum_i (\hat{y_i} - \bar{y})^2 / n + * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]] */ def explainedVariance: Double = { - 1 - summary.variance(1) / summary.variance(0) + SSreg / summary.count } /** @@ -76,8 +84,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * expected value of the squared error loss or quadratic loss. */ def meanSquaredError: Double = { - val rmse = summary.normL2(1) / math.sqrt(summary.count) - rmse * rmse + SSerr / summary.count } /** @@ -85,14 +92,14 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * the mean squared error. */ def rootMeanSquaredError: Double = { - summary.normL2(1) / math.sqrt(summary.count) + math.sqrt(this.meanSquaredError) } /** - * Returns R^2^, the coefficient of determination. - * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * Returns R^2^, the unadjusted coefficient of determination. + * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] */ def r2: Double = { - 1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1)) + 1 - SSerr / SStot } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 9de2bdb6d7246..4b7f1be58f99b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -23,24 +23,85 @@ import org.apache.spark.mllib.util.TestingUtils._ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { - test("regression metrics") { + test("regression metrics for unbiased (includes intercept term) predictor") { + /* Verify results in R: + preds = c(2.25, -0.25, 1.75, 7.75) + obs = c(3.0, -0.5, 2.0, 7.0) + + SStot = sum((obs - mean(obs))^2) + SSreg = sum((preds - mean(obs))^2) + SSerr = sum((obs - preds)^2) + + explainedVariance = SSreg / length(obs) + explainedVariance + > [1] 8.796875 + meanAbsoluteError = mean(abs(preds - obs)) + meanAbsoluteError + > [1] 0.5 + meanSquaredError = mean((preds - obs)^2) + meanSquaredError + > [1] 0.3125 + rmse = sqrt(meanSquaredError) + rmse + > [1] 0.559017 + r2 = 1 - SSerr / SStot + r2 + > [1] 0.9571734 + */ + val predictionAndObservations = sc.parallelize( + Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2) + val metrics = new RegressionMetrics(predictionAndObservations) + assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5, + "root mean squared error mismatch") + assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch") + } + + test("regression metrics for biased (no intercept term) predictor") { + /* Verify results in R: + preds = c(2.5, 0.0, 2.0, 8.0) + obs = c(3.0, -0.5, 2.0, 7.0) + + SStot = sum((obs - mean(obs))^2) + SSreg = sum((preds - mean(obs))^2) + SSerr = sum((obs - preds)^2) + + explainedVariance = SSreg / length(obs) + explainedVariance + > [1] 8.859375 + meanAbsoluteError = mean(abs(preds - obs)) + meanAbsoluteError + > [1] 0.5 + meanSquaredError = mean((preds - obs)^2) + meanSquaredError + > [1] 0.375 + rmse = 
sqrt(meanSquaredError) + rmse + > [1] 0.6123724 + r2 = 1 - SSerr / SStot + r2 + > [1] 0.9486081 + */ val predictionAndObservations = sc.parallelize( Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2) val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5, + assert(metrics.explainedVariance ~== 8.85937 absTol 1E-5, "explained variance regression score mismatch") assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch") assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5, "root mean squared error mismatch") - assert(metrics.r2 ~== 0.94861 absTol 1E-5, "r2 score mismatch") + assert(metrics.r2 ~== 0.94860 absTol 1E-5, "r2 score mismatch") } test("regression metrics with complete fitting") { val predictionAndObservations = sc.parallelize( Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2) val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 1.0 absTol 1E-5, + assert(metrics.explainedVariance ~== 8.6875 absTol 1E-5, "explained variance regression score mismatch") assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch") assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch") diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index f21403707e12a..4398ca86f2ec2 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -82,7 +82,7 @@ class RegressionMetrics(JavaModelWrapper): ... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)]) >>> metrics = RegressionMetrics(predictionAndObservations) >>> metrics.explainedVariance - 0.95... + 8.859... >>> metrics.meanAbsoluteError 0.5... >>> metrics.meanSquaredError From b9a922e260bec1b211437f020be37fab46a85db0 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 15 Jul 2015 14:02:23 -0700 Subject: [PATCH 30/46] [SPARK-6602][Core]Replace Akka Serialization with Spark Serializer Replace Akka Serialization with Spark Serializer and add unit tests. 
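As a rough sketch of the round trip the reworked persistence engines rely on (illustrative only: the string payload, object name, and default SparkConf are assumptions; the serializer calls mirror the diff below):

import java.nio.ByteBuffer
import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

object SerializerRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // Serialize to a ByteBuffer, copy out the bytes (as the ZooKeeper engine does),
    // then deserialize with a ClassTag-typed call -- no Akka ExtendedActorSystem needed.
    val instance = new JavaSerializer(new SparkConf()).newInstance()
    val buf: ByteBuffer = instance.serialize("persisted-state")
    val bytes = new Array[Byte](buf.remaining())
    buf.get(bytes)
    val restored = instance.deserialize[String](ByteBuffer.wrap(bytes))
    assert(restored == "persisted-state")
  }
}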
Author: zsxwing Closes #7159 from zsxwing/remove-akka-serialization and squashes the following commits: fc0fca3 [zsxwing] Merge branch 'master' into remove-akka-serialization cf81a58 [zsxwing] Fix the code style 73251c6 [zsxwing] Add test scope 9ef4af9 [zsxwing] Add AkkaRpcEndpointRef.hashCode 433115c [zsxwing] Remove final be3edb0 [zsxwing] Support deserializing RpcEndpointRef ecec410 [zsxwing] Replace Akka Serialization with Spark Serializer --- core/pom.xml | 5 + .../master/FileSystemPersistenceEngine.scala | 35 ++--- .../apache/spark/deploy/master/Master.scala | 18 +-- .../deploy/master/PersistenceEngine.scala | 8 +- .../deploy/master/RecoveryModeFactory.scala | 9 +- .../master/ZooKeeperPersistenceEngine.scala | 16 +-- .../scala/org/apache/spark/rpc/RpcEnv.scala | 6 + .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 14 +- .../master/CustomRecoveryModeFactory.scala | 31 ++--- .../spark/deploy/master/MasterSuite.scala | 2 +- .../master/PersistenceEngineSuite.scala | 126 ++++++++++++++++++ pom.xml | 6 + 12 files changed, 214 insertions(+), 62 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala diff --git a/core/pom.xml b/core/pom.xml index 558cc3fb9f2f3..73f7a75cab9d3 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -372,6 +372,11 @@ junit-interface test + + org.apache.curator + curator-test + test + net.razorvine pyrolite diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index f459ed5b3a1a1..aa379d4cd61e7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -21,9 +21,8 @@ import java.io._ import scala.reflect.ClassTag -import akka.serialization.Serialization - import org.apache.spark.Logging +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer} import org.apache.spark.util.Utils @@ -32,11 +31,11 @@ import org.apache.spark.util.Utils * Files are deleted when applications and workers are removed. * * @param dir Directory to store files. Created if non-existent (but not recursively). - * @param serialization Used to serialize our objects. + * @param serializer Used to serialize our objects. 
*/ private[master] class FileSystemPersistenceEngine( val dir: String, - val serialization: Serialization) + val serializer: Serializer) extends PersistenceEngine with Logging { new File(dir).mkdir() @@ -57,27 +56,31 @@ private[master] class FileSystemPersistenceEngine( private def serializeIntoFile(file: File, value: AnyRef) { val created = file.createNewFile() if (!created) { throw new IllegalStateException("Could not create file: " + file) } - val serializer = serialization.findSerializerFor(value) - val serialized = serializer.toBinary(value) - val out = new FileOutputStream(file) + val fileOut = new FileOutputStream(file) + var out: SerializationStream = null Utils.tryWithSafeFinally { - out.write(serialized) + out = serializer.newInstance().serializeStream(fileOut) + out.writeObject(value) } { - out.close() + fileOut.close() + if (out != null) { + out.close() + } } } private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = { - val fileData = new Array[Byte](file.length().asInstanceOf[Int]) - val dis = new DataInputStream(new FileInputStream(file)) + val fileIn = new FileInputStream(file) + var in: DeserializationStream = null try { - dis.readFully(fileData) + in = serializer.newInstance().deserializeStream(fileIn) + in.readObject[T]() } finally { - dis.close() + fileIn.close() + if (in != null) { + in.close() + } } - val clazz = m.runtimeClass.asInstanceOf[Class[T]] - val serializer = serialization.serializerFor(clazz) - serializer.fromBinary(fileData).asInstanceOf[T] } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 245b047e7dfbd..4615febf17d24 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -27,11 +27,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.language.postfixOps import scala.util.Random -import akka.serialization.Serialization -import akka.serialization.SerializationExtension import org.apache.hadoop.fs.Path -import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, @@ -44,6 +41,7 @@ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} +import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.ui.SparkUI import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} @@ -58,9 +56,6 @@ private[master] class Master( private val forwardMessageThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") - // TODO Remove it once we don't use akka.serialization.Serialization - private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem - private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs @@ -161,20 +156,21 @@ private[master] class Master( masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler) applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler) + val serializer = new JavaSerializer(conf) val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match { case "ZOOKEEPER" 
=> logInfo("Persisting recovery state to ZooKeeper") val zkFactory = - new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem)) + new ZooKeeperRecoveryModeFactory(conf, serializer) (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => val fsFactory = - new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem)) + new FileSystemRecoveryModeFactory(conf, serializer) (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory")) - val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization]) - .newInstance(conf, SerializationExtension(actorSystem)) + val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serializer]) + .newInstance(conf, serializer) .asInstanceOf[StandaloneRecoveryModeFactory] (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => @@ -213,7 +209,7 @@ private[master] class Master( override def receive: PartialFunction[Any, Unit] = { case ElectedLeader => { - val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() + val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv) state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { RecoveryState.ALIVE } else { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index a03d460509e03..58a00bceee6af 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.master import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEnv import scala.reflect.ClassTag @@ -80,8 +81,11 @@ abstract class PersistenceEngine { * Returns the persisted data sorted by their respective ids (which implies that they're * sorted by time of creation). 
*/ - final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { - (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + final def readPersistedData( + rpcEnv: RpcEnv): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { + rpcEnv.deserialize { () => + (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + } } def close() {} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala index 351db8fab2041..c4c3283fb73f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -17,10 +17,9 @@ package org.apache.spark.deploy.master -import akka.serialization.Serialization - import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.serializer.Serializer /** * ::DeveloperApi:: @@ -30,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi * */ @DeveloperApi -abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) { +abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serializer) { /** * PersistenceEngine defines how the persistent data(Information about worker, driver etc..) @@ -49,7 +48,7 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serial * LeaderAgent in this case is a no-op. Since leader is forever leader as the actual * recovery is made by restoring from filesystem. */ -private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization) +private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serializer) extends StandaloneRecoveryModeFactory(conf, serializer) with Logging { val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") @@ -64,7 +63,7 @@ private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: } } -private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization) +private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serializer) extends StandaloneRecoveryModeFactory(conf, serializer) { def createPersistenceEngine(): PersistenceEngine = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 328d95a7a0c68..563831cc6b8dd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.master -import akka.serialization.Serialization +import java.nio.ByteBuffer import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -27,9 +27,10 @@ import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.SparkCuratorUtil +import org.apache.spark.serializer.Serializer -private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) +private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer: Serializer) extends PersistenceEngine with Logging { @@ -57,17 +58,16 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializat } 
private def serializeIntoFile(path: String, value: AnyRef) { - val serializer = serialization.findSerializerFor(value) - val serialized = serializer.toBinary(value) - zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized) + val serialized = serializer.newInstance().serialize(value) + val bytes = new Array[Byte](serialized.remaining()) + serialized.get(bytes) + zk.create().withMode(CreateMode.PERSISTENT).forPath(path, bytes) } private def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = { val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename) - val clazz = m.runtimeClass.asInstanceOf[Class[T]] - val serializer = serialization.serializerFor(clazz) try { - Some(serializer.fromBinary(fileData).asInstanceOf[T]) + Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData))) } catch { case e: Exception => { logWarning("Exception while reading persisted file, deleting", e) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index c9fcc7a36cc04..29debe8081308 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -139,6 +139,12 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * creating it manually because different [[RpcEnv]] may have different formats. */ def uriOf(systemName: String, address: RpcAddress, endpointName: String): String + + /** + * [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object + * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. + */ + def deserialize[T](deserializationAction: () => T): T } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index f2d87f68341af..fc17542abf81d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -28,7 +28,7 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add import akka.event.Logging.Error import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} -import com.google.common.util.concurrent.MoreExecutors +import akka.serialization.JavaSerializer import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.rpc._ @@ -239,6 +239,12 @@ private[spark] class AkkaRpcEnv private[akka] ( } override def toString: String = s"${getClass.getSimpleName}($actorSystem)" + + override def deserialize[T](deserializationAction: () => T): T = { + JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) { + deserializationAction() + } + } } private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { @@ -315,6 +321,12 @@ private[akka] class AkkaRpcEndpointRef( override def toString: String = s"${getClass.getSimpleName}($actorRef)" + final override def equals(that: Any): Boolean = that match { + case other: AkkaRpcEndpointRef => actorRef == other.actorRef + case _ => false + } + + final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode() } /** diff --git a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala index f4e56632e426a..8c96b0e71dfdd 100644 --- 
a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala @@ -19,18 +19,19 @@ // when they are outside of org.apache.spark. package other.supplier +import java.nio.ByteBuffer + import scala.collection.mutable import scala.reflect.ClassTag -import akka.serialization.Serialization - import org.apache.spark.SparkConf import org.apache.spark.deploy.master._ +import org.apache.spark.serializer.Serializer class CustomRecoveryModeFactory( conf: SparkConf, - serialization: Serialization -) extends StandaloneRecoveryModeFactory(conf, serialization) { + serializer: Serializer +) extends StandaloneRecoveryModeFactory(conf, serializer) { CustomRecoveryModeFactory.instantiationAttempts += 1 @@ -40,7 +41,7 @@ class CustomRecoveryModeFactory( * */ override def createPersistenceEngine(): PersistenceEngine = - new CustomPersistenceEngine(serialization) + new CustomPersistenceEngine(serializer) /** * Create an instance of LeaderAgent that decides who gets elected as master. @@ -53,7 +54,7 @@ object CustomRecoveryModeFactory { @volatile var instantiationAttempts = 0 } -class CustomPersistenceEngine(serialization: Serialization) extends PersistenceEngine { +class CustomPersistenceEngine(serializer: Serializer) extends PersistenceEngine { val data = mutable.HashMap[String, Array[Byte]]() CustomPersistenceEngine.lastInstance = Some(this) @@ -64,10 +65,10 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE */ override def persist(name: String, obj: Object): Unit = { CustomPersistenceEngine.persistAttempts += 1 - serialization.serialize(obj) match { - case util.Success(bytes) => data += name -> bytes - case util.Failure(cause) => throw new RuntimeException(cause) - } + val serialized = serializer.newInstance().serialize(obj) + val bytes = new Array[Byte](serialized.remaining()) + serialized.get(bytes) + data += name -> bytes } /** @@ -84,15 +85,9 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE */ override def read[T: ClassTag](prefix: String): Seq[T] = { CustomPersistenceEngine.readAttempts += 1 - val clazz = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] val results = for ((name, bytes) <- data; if name.startsWith(prefix)) - yield serialization.deserialize(bytes, clazz) - - results.find(_.isFailure).foreach { - case util.Failure(cause) => throw new RuntimeException(cause) - } - - results.flatMap(_.toOption).toSeq + yield serializer.newInstance().deserialize[T](ByteBuffer.wrap(bytes)) + results.toSeq } } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 9cb6dd43bac47..a8fbaf1d9da0a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -105,7 +105,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually { persistenceEngine.addDriver(driverToPersist) persistenceEngine.addWorker(workerToPersist) - val (apps, drivers, workers) = persistenceEngine.readPersistedData() + val (apps, drivers, workers) = persistenceEngine.readPersistedData(rpcEnv) apps.map(_.id) should contain(appToPersist.id) drivers.map(_.id) should contain(driverToPersist.id) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala 
b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala new file mode 100644 index 0000000000000..11e87bd1dd8eb --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.deploy.master + +import java.net.ServerSocket + +import org.apache.commons.lang3.RandomUtils +import org.apache.curator.test.TestingServer + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.rpc.{RpcEndpoint, RpcEnv} +import org.apache.spark.serializer.{Serializer, JavaSerializer} +import org.apache.spark.util.Utils + +class PersistenceEngineSuite extends SparkFunSuite { + + test("FileSystemPersistenceEngine") { + val dir = Utils.createTempDir() + try { + val conf = new SparkConf() + testPersistenceEngine(conf, serializer => + new FileSystemPersistenceEngine(dir.getAbsolutePath, serializer) + ) + } finally { + Utils.deleteRecursively(dir) + } + } + + test("ZooKeeperPersistenceEngine") { + val conf = new SparkConf() + // TestingServer logs the port conflict exception rather than throwing an exception. + // So we have to find a free port by ourselves. This approach cannot guarantee always starting + // zkTestServer successfully because there is a time gap between finding a free port and + // starting zkTestServer. But the failure possibility should be very low. 
+ val zkTestServer = new TestingServer(findFreePort(conf)) + try { + testPersistenceEngine(conf, serializer => { + conf.set("spark.deploy.zookeeper.url", zkTestServer.getConnectString) + new ZooKeeperPersistenceEngine(conf, serializer) + }) + } finally { + zkTestServer.stop() + } + } + + private def testPersistenceEngine( + conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = { + val serializer = new JavaSerializer(conf) + val persistenceEngine = persistenceEngineCreator(serializer) + persistenceEngine.persist("test_1", "test_1_value") + assert(Seq("test_1_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.persist("test_2", "test_2_value") + assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) + persistenceEngine.unpersist("test_1") + assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.unpersist("test_2") + assert(persistenceEngine.read[String]("test_").isEmpty) + + // Test deserializing objects that contain RpcEndpointRef + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + try { + // Create a real endpoint so that we can test RpcEndpointRef deserialization + val workerEndpoint = rpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = rpcEnv + }) + + val workerToPersist = new WorkerInfo( + id = "test_worker", + host = "127.0.0.1", + port = 10000, + cores = 0, + memory = 0, + endpoint = workerEndpoint, + webUiPort = 0, + publicAddress = "" + ) + + persistenceEngine.addWorker(workerToPersist) + + val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv) + + assert(storedApps.isEmpty) + assert(storedDrivers.isEmpty) + + // Check deserializing WorkerInfo + assert(storedWorkers.size == 1) + val recoveryWorkerInfo = storedWorkers.head + assert(workerToPersist.id === recoveryWorkerInfo.id) + assert(workerToPersist.host === recoveryWorkerInfo.host) + assert(workerToPersist.port === recoveryWorkerInfo.port) + assert(workerToPersist.cores === recoveryWorkerInfo.cores) + assert(workerToPersist.memory === recoveryWorkerInfo.memory) + assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) + assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) + assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress) + } finally { + rpcEnv.shutdown() + rpcEnv.awaitTermination() + } + } + + private def findFreePort(conf: SparkConf): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, conf)._2 + } +} diff --git a/pom.xml b/pom.xml index 370c95dd03632..aa49e2ab7294b 100644 --- a/pom.xml +++ b/pom.xml @@ -748,6 +748,12 @@ curator-framework ${curator.version} + + org.apache.curator + curator-test + ${curator.version} + test + org.apache.hadoop hadoop-client From 674eb2a4c3ff595760f990daf369ba75d2547593 Mon Sep 17 00:00:00 2001 From: KaiXinXiaoLei Date: Wed, 15 Jul 2015 22:31:10 +0100 Subject: [PATCH 31/46] [SPARK-8974] Catch exceptions in allocation schedule task. I meet a problem. When I submit some tasks, the thread spark-dynamic-executor-allocation should seed the message about "requestTotalExecutors", and the new executor should start. 
But I meet a problem about this thread, like: 2015-07-14 19:02:17,461 | WARN | [spark-dynamic-executor-allocation] | Error sending message [message = RequestExecutors(1)] in 1 attempts java.util.concurrent.TimeoutException: Futures timed out after [120 seconds] at scala.concurrent.impl.Promise$DefaultPromise.ready(Promise.scala:219) at scala.concurrent.impl.Promise$DefaultPromise.result(Promise.scala:223) at scala.concurrent.Await$$anonfun$result$1.apply(package.scala:107) at scala.concurrent.BlockContext$DefaultBlockContext$.blockOn(BlockContext.scala:53) at scala.concurrent.Await$.result(package.scala:107) at org.apache.spark.rpc.RpcEndpointRef.askWithRetry(RpcEndpointRef.scala:102) at org.apache.spark.rpc.RpcEndpointRef.askWithRetry(RpcEndpointRef.scala:78) at org.apache.spark.scheduler.cluster.YarnSchedulerBackend.doRequestTotalExecutors(YarnSchedulerBackend.scala:57) at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.requestTotalExecutors(CoarseGrainedSchedulerBackend.scala:351) at org.apache.spark.SparkContext.requestTotalExecutors(SparkContext.scala:1382) at org.apache.spark.ExecutorAllocationManager.addExecutors(ExecutorAllocationManager.scala:343) at org.apache.spark.ExecutorAllocationManager.updateAndSyncNumExecutorsTarget(ExecutorAllocationManager.scala:295) at org.apache.spark.ExecutorAllocationManager.org$apache$spark$ExecutorAllocationManager$$schedule(ExecutorAllocationManager.scala:248) when after some minutes, I find a new ApplicationMaster start, and tasks submitted start to run. The tasks Completed. And after long time (eg, ten minutes), the number of executor does not reduce to zero. I use the default value of "spark.dynamicAllocation.minExecutors". Author: KaiXinXiaoLei Closes #7352 from KaiXinXiaoLei/dym and squashes the following commits: 3603631 [KaiXinXiaoLei] change logError to logWarning efc4f24 [KaiXinXiaoLei] change file --- .../org/apache/spark/ExecutorAllocationManager.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 0c50b4002cf7b..648bcfe28cad2 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.util.concurrent.TimeUnit import scala.collection.mutable +import scala.util.control.ControlThrowable import com.codahale.metrics.{Gauge, MetricRegistry} @@ -211,7 +212,16 @@ private[spark] class ExecutorAllocationManager( listenerBus.addListener(listener) val scheduleTask = new Runnable() { - override def run(): Unit = Utils.logUncaughtExceptions(schedule()) + override def run(): Unit = { + try { + schedule() + } catch { + case ct: ControlThrowable => + throw ct + case t: Throwable => + logWarning(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + } + } } executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) } From affbe329ae0100bd50a3c3fb081b0f2b07efce33 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 15 Jul 2015 14:52:02 -0700 Subject: [PATCH 32/46] [SPARK-9071][SQL] MonotonicallyIncreasingID and SparkPartitionID should be marked as nondeterministic. I also took the chance to more explicitly define the semantics of deterministic. 
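As a sketch of what the clarified contract means in practice (this mirrors the shape of the patched SparkPartitionID in the diff below; the object name is made up for illustration), a leaf expression that reads implicit task state instead of its children should report itself as non-deterministic:

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.LeafExpression
import org.apache.spark.sql.types.{DataType, IntegerType}

// Its value depends on which partition evaluates it, not on any child
// expression, so the optimizer must not treat it as deterministic.
case object PartitionIdOfCurrentTask extends LeafExpression {
  override def deterministic: Boolean = false
  override def nullable: Boolean = false
  override def dataType: DataType = IntegerType
  override def eval(input: InternalRow): Int = TaskContext.getPartitionId()
}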
Author: Reynold Xin Closes #7428 from rxin/non-deterministic and squashes the following commits: a760827 [Reynold Xin] [SPARK-9071][SQL] MonotonicallyIncreasingID and SparkPartitionID should be marked as nondeterministic. --- .../spark/sql/catalyst/expressions/Expression.scala | 10 ++++++++-- .../expressions/MonotonicallyIncreasingID.scala | 4 +++- .../sql/execution/expressions/SparkPartitionID.scala | 4 +++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 3f19ac2b592b5..7b37ae7335253 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -61,9 +61,15 @@ abstract class Expression extends TreeNode[Expression] { def foldable: Boolean = false /** - * Returns true when the current expression always return the same result for fixed input values. + * Returns true when the current expression always return the same result for fixed inputs from + * children. + * + * Note that this means that an expression should be considered as non-deterministic if: + * - if it relies on some mutable internal state, or + * - if it relies on some implicit input that is not part of the children expression list. + * + * An example would be `SparkPartitionID` that relies on the partition id returned by TaskContext. */ - // TODO: Need to define explicit input values vs implicit input values. def deterministic: Boolean = true def nullable: Boolean diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 69a37750d7525..fec403fe2d348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -41,7 +41,9 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { */ @transient private[this] var count: Long = 0L - @transient private lazy val partitionMask = TaskContext.getPartitionId.toLong << 33 + @transient private lazy val partitionMask = TaskContext.getPartitionId().toLong << 33 + + override def deterministic: Boolean = false override def nullable: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index 5f1b514f2cff2..7c790c549a5d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -29,11 +29,13 @@ import org.apache.spark.sql.types.{IntegerType, DataType} */ private[sql] case object SparkPartitionID extends LeafExpression { + override def deterministic: Boolean = false + override def nullable: Boolean = false override def dataType: DataType = IntegerType - @transient private lazy val partitionId = TaskContext.getPartitionId + @transient private lazy val partitionId = TaskContext.getPartitionId() override def eval(input: InternalRow): Int = partitionId From b0645195d0da57065885e078e08bd6c42f4f19b0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: 
Wed, 15 Jul 2015 17:50:11 -0700 Subject: [PATCH 33/46] [SPARK-9086][SQL] Remove BinaryNode from TreeNode. These traits are not super useful, and yet cause problems with toString in expressions due to the orders they are mixed in. Author: Reynold Xin Closes #7433 from rxin/remove-binary-node and squashes the following commits: 1881f78 [Reynold Xin] [SPARK-9086][SQL] Remove BinaryNode from TreeNode. --- .../sql/catalyst/expressions/Expression.scala | 17 ++++++++++++++--- .../catalyst/plans/logical/LogicalPlan.scala | 7 ++++++- .../spark/sql/catalyst/trees/TreeNode.scala | 9 --------- .../apache/spark/sql/execution/SparkPlan.scala | 7 ++++++- 4 files changed, 26 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 7b37ae7335253..87667316aca67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -187,8 +187,10 @@ abstract class Expression extends TreeNode[Expression] { /** * A leaf expression, i.e. one without any child expressions. */ -abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { +abstract class LeafExpression extends Expression { self: Product => + + def children: Seq[Expression] = Nil } @@ -196,9 +198,13 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] * An expression with one input and one output. The output is by default evaluated to null * if the input is evaluated to null. */ -abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { +abstract class UnaryExpression extends Expression { self: Product => + def child: Expression + + override def children: Seq[Expression] = child :: Nil + override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable @@ -271,9 +277,14 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio * An expression with two inputs and one output. The output is by default evaluated to null * if any input is evaluated to null. */ -abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { +abstract class BinaryExpression extends Expression { self: Product => + def left: Expression + def right: Expression + + override def children: Seq[Expression] = Seq(left, right) + override def foldable: Boolean = left.foldable && right.foldable override def nullable: Boolean = left.nullable || right.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e911b907e8536..d7077a0ec907a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -291,6 +291,11 @@ abstract class UnaryNode extends LogicalPlan with trees.UnaryNode[LogicalPlan] { /** * A logical plan node with a left and right child. 
*/ -abstract class BinaryNode extends LogicalPlan with trees.BinaryNode[LogicalPlan] { +abstract class BinaryNode extends LogicalPlan { self: Product => + + def left: LogicalPlan + def right: LogicalPlan + + override def children: Seq[LogicalPlan] = Seq(left, right) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 09f6c6b0ec423..16844b2f4b680 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -453,15 +453,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } } -/** - * A [[TreeNode]] that has two children, [[left]] and [[right]]. - */ -trait BinaryNode[BaseType <: TreeNode[BaseType]] { - def left: BaseType - def right: BaseType - - def children: Seq[BaseType] = Seq(left, right) -} /** * A [[TreeNode]] with no children. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 4d7d8626a0ecc..9dc7879fa4a1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -247,6 +247,11 @@ private[sql] trait UnaryNode extends SparkPlan with trees.UnaryNode[SparkPlan] { override def outputPartitioning: Partitioning = child.outputPartitioning } -private[sql] trait BinaryNode extends SparkPlan with trees.BinaryNode[SparkPlan] { +private[sql] trait BinaryNode extends SparkPlan { self: Product => + + def left: SparkPlan + def right: SparkPlan + + override def children: Seq[SparkPlan] = Seq(left, right) } From 6960a7938c61cc07f181ca85e0d8152ceeb453d9 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 15 Jul 2015 20:33:06 -0700 Subject: [PATCH 34/46] [SPARK-8774] [ML] Add R model formula with basic support as a transformer This implements minimal R formula support as a feature transformer. Both numeric and string labels are supported, but features must be numeric for now. cc mengxr Author: Eric Liang Closes #7381 from ericl/spark-8774-1 and squashes the following commits: d1959d2 [Eric Liang] clarify comment 2db68aa [Eric Liang] second round of comments dc3c943 [Eric Liang] address comments 5765ec6 [Eric Liang] fix style checks 1f361b0 [Eric Liang] doc fb0826b [Eric Liang] [SPARK-8774] Add R model formula with basic support as a transformer --- .../apache/spark/ml/feature/RFormula.scala | 151 ++++++++++++++++++ .../spark/ml/feature/VectorAssembler.scala | 2 +- .../ml/feature/RFormulaParserSuite.scala | 34 ++++ .../spark/ml/feature/RFormulaSuite.scala | 93 +++++++++++ 4 files changed, 279 insertions(+), 1 deletion(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala new file mode 100644 index 0000000000000..d9a36bda386b3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.util.parsing.combinator.RegexParsers + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Implements the transforms required for fitting a dataset against an R model formula. Currently + * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula + * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + */ +@Experimental +class RFormula(override val uid: String) + extends Transformer with HasFeaturesCol with HasLabelCol { + + def this() = this(Identifiable.randomUID("rFormula")) + + /** + * R formula parameter. The formula is provided in string form. + * @group setParam + */ + val formula: Param[String] = new Param(this, "formula", "R model formula") + + private var parsedFormula: Option[ParsedRFormula] = None + + /** + * Sets the formula to use for this transformer. Must be called before use. + * @group setParam + * @param value an R formula in string form (e.g. 
"y ~ x + z") + */ + def setFormula(value: String): this.type = { + parsedFormula = Some(RFormulaParser.parse(value)) + set(formula, value) + this + } + + /** @group getParam */ + def getFormula: String = $(formula) + + /** @group getParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group getParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + override def transformSchema(schema: StructType): StructType = { + checkCanTransform(schema) + val withFeatures = transformFeatures.transformSchema(schema) + if (hasLabelCol(schema)) { + withFeatures + } else { + val nullable = schema(parsedFormula.get.label).dataType match { + case _: NumericType | BooleanType => false + case _ => true + } + StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable)) + } + } + + override def transform(dataset: DataFrame): DataFrame = { + checkCanTransform(dataset.schema) + transformLabel(transformFeatures.transform(dataset)) + } + + override def copy(extra: ParamMap): RFormula = defaultCopy(extra) + + override def toString: String = s"RFormula(${get(formula)})" + + private def transformLabel(dataset: DataFrame): DataFrame = { + if (hasLabelCol(dataset.schema)) { + dataset + } else { + val labelName = parsedFormula.get.label + dataset.schema(labelName).dataType match { + case _: NumericType | BooleanType => + dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType)) + // TODO(ekl) add support for string-type labels + case other => + throw new IllegalArgumentException("Unsupported type for label: " + other) + } + } + } + + private def transformFeatures: Transformer = { + // TODO(ekl) add support for non-numeric features and feature interactions + new VectorAssembler(uid) + .setInputCols(parsedFormula.get.terms.toArray) + .setOutputCol($(featuresCol)) + } + + private def checkCanTransform(schema: StructType) { + require(parsedFormula.isDefined, "Must call setFormula() first.") + val columnNames = schema.map(_.name) + require(!columnNames.contains($(featuresCol)), "Features column already exists.") + require( + !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType, + "Label column already exists and is not of type DoubleType.") + } + + private def hasLabelCol(schema: StructType): Boolean = { + schema.map(_.name).contains($(labelCol)) + } +} + +/** + * Represents a parsed R formula. + */ +private[ml] case class ParsedRFormula(label: String, terms: Seq[String]) + +/** + * Limited implementation of R formula parsing. Currently supports: '~', '+'. 
+ */ +private[ml] object RFormulaParser extends RegexParsers { + def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r + + def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list } + + def formula: Parser[ParsedRFormula] = + (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } + + def parse(value: String): ParsedRFormula = parseAll(formula, value) match { + case Success(result, _) => result + case failure: NoSuccess => throw new IllegalArgumentException( + "Could not parse formula: " + value) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 9f83c2ee16178..086917fa680f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -116,7 +116,7 @@ class VectorAssembler(override val uid: String) if (schema.fieldNames.contains(outputColName)) { throw new IllegalArgumentException(s"Output column $outputColName already exists.") } - StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false)) + StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, true)) } override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala new file mode 100644 index 0000000000000..c8d065f37a605 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite + +class RFormulaParserSuite extends SparkFunSuite { + private def checkParse(formula: String, label: String, terms: Seq[String]) { + val parsed = RFormulaParser.parse(formula) + assert(parsed.label == label) + assert(parsed.terms == terms) + } + + test("parse simple formulas") { + checkParse("y ~ x", "y", Seq("x")) + checkParse("y ~ ._foo ", "y", Seq("._foo")) + checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123")) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala new file mode 100644 index 0000000000000..fa8611b243a9f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new RFormula()) + } + + test("transform numeric data") { + val formula = new RFormula().setFormula("id ~ v1 + v2") + val original = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + val result = formula.transform(original) + val resultSchema = formula.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq( + (0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0), + (2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0)) + ).toDF("id", "v1", "v2", "features", "label") + // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString + assert(result.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(result.collect().toSeq == expected.collect().toSeq) + } + + test("features column already exists") { + val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x") + val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + intercept[IllegalArgumentException] { + formula.transformSchema(original.schema) + } + intercept[IllegalArgumentException] { + formula.transform(original) + } + } + + test("label column already exists") { + val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") + val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + val resultSchema = formula.transformSchema(original.schema) + assert(resultSchema.length == 3) + assert(resultSchema.toString == formula.transform(original).schema.toString) + } + + test("label column already exists but is not double type") { + val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") + val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y") + intercept[IllegalArgumentException] { + formula.transformSchema(original.schema) + } + intercept[IllegalArgumentException] { + formula.transform(original) + } + } + +// TODO(ekl) enable after we implement string label support +// test("transform string label") { +// val formula = new RFormula().setFormula("name ~ id") +// val original = sqlContext.createDataFrame( +// Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name") +// val result = formula.transform(original) +// val resultSchema = formula.transformSchema(original.schema) +// val expected = sqlContext.createDataFrame( +// Seq( +// (1, "foo", Vectors.dense(Array(1.0)), 1.0), +// (2, "bar", Vectors.dense(Array(2.0)), 0.0), +// (3, "bar", Vectors.dense(Array(3.0)), 0.0)) +// ).toDF("id", "name", "features", 
"label") +// assert(result.schema.toString == resultSchema.toString) +// assert(result.collect().toSeq == expected.collect().toSeq) +// } +} From 73d92b00b9a6f5dfc2f8116447d17b381cd74f80 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 15 Jul 2015 21:02:42 -0700 Subject: [PATCH 35/46] [SPARK-9018] [MLLIB] add stopwatches Add stopwatches for easy instrumentation of MLlib algorithms. This is based on the `TimeTracker` used in decision trees. The distributed version uses Spark accumulator. jkbradley Author: Xiangrui Meng Closes #7415 from mengxr/SPARK-9018 and squashes the following commits: 40b4347 [Xiangrui Meng] == -> === c477745 [Xiangrui Meng] address Joseph's comments f981a49 [Xiangrui Meng] add stopwatches --- .../apache/spark/ml/util/stopwatches.scala | 151 ++++++++++++++++++ .../apache/spark/ml/util/StopwatchSuite.scala | 109 +++++++++++++ 2 files changed, 260 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala new file mode 100644 index 0000000000000..5fdf878a3df72 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import scala.collection.mutable + +import org.apache.spark.{Accumulator, SparkContext} + +/** + * Abstract class for stopwatches. + */ +private[spark] abstract class Stopwatch extends Serializable { + + @transient private var running: Boolean = false + private var startTime: Long = _ + + /** + * Name of the stopwatch. + */ + val name: String + + /** + * Starts the stopwatch. + * Throws an exception if the stopwatch is already running. + */ + def start(): Unit = { + assume(!running, "start() called but the stopwatch is already running.") + running = true + startTime = now + } + + /** + * Stops the stopwatch and returns the duration of the last session in milliseconds. + * Throws an exception if the stopwatch is not running. + */ + def stop(): Long = { + assume(running, "stop() called but the stopwatch is not running.") + val duration = now - startTime + add(duration) + running = false + duration + } + + /** + * Checks whether the stopwatch is running. + */ + def isRunning: Boolean = running + + /** + * Returns total elapsed time in milliseconds, not counting the current session if the stopwatch + * is running. + */ + def elapsed(): Long + + /** + * Gets the current time in milliseconds. + */ + protected def now: Long = System.currentTimeMillis() + + /** + * Adds input duration to total elapsed time. 
+ */ + protected def add(duration: Long): Unit +} + +/** + * A local [[Stopwatch]]. + */ +private[spark] class LocalStopwatch(override val name: String) extends Stopwatch { + + private var elapsedTime: Long = 0L + + override def elapsed(): Long = elapsedTime + + override protected def add(duration: Long): Unit = { + elapsedTime += duration + } +} + +/** + * A distributed [[Stopwatch]] using Spark accumulator. + * @param sc SparkContext + */ +private[spark] class DistributedStopwatch( + sc: SparkContext, + override val name: String) extends Stopwatch { + + private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)") + + override def elapsed(): Long = elapsedTime.value + + override protected def add(duration: Long): Unit = { + elapsedTime += duration + } +} + +/** + * A multiple stopwatch that contains local and distributed stopwatches. + * @param sc SparkContext + */ +private[spark] class MultiStopwatch(@transient private val sc: SparkContext) extends Serializable { + + private val stopwatches: mutable.Map[String, Stopwatch] = mutable.Map.empty + + /** + * Adds a local stopwatch. + * @param name stopwatch name + */ + def addLocal(name: String): this.type = { + require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.") + stopwatches(name) = new LocalStopwatch(name) + this + } + + /** + * Adds a distributed stopwatch. + * @param name stopwatch name + */ + def addDistributed(name: String): this.type = { + require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.") + stopwatches(name) = new DistributedStopwatch(sc, name) + this + } + + /** + * Gets a stopwatch. + * @param name stopwatch name + */ + def apply(name: String): Stopwatch = stopwatches(name) + + override def toString: String = { + stopwatches.values.toArray.sortBy(_.name) + .map(c => s" ${c.name}: ${c.elapsed()}ms") + .mkString("{\n", ",\n", "\n}") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala new file mode 100644 index 0000000000000..8df6617fe0228 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { + + private def testStopwatchOnDriver(sw: Stopwatch): Unit = { + assert(sw.name === "sw") + assert(sw.elapsed() === 0L) + assert(!sw.isRunning) + intercept[AssertionError] { + sw.stop() + } + sw.start() + Thread.sleep(50) + val duration = sw.stop() + assert(duration >= 50 && duration < 100) // using a loose upper bound + val elapsed = sw.elapsed() + assert(elapsed === duration) + sw.start() + Thread.sleep(50) + val duration2 = sw.stop() + assert(duration2 >= 50 && duration2 < 100) + val elapsed2 = sw.elapsed() + assert(elapsed2 === duration + duration2) + sw.start() + assert(sw.isRunning) + intercept[AssertionError] { + sw.start() + } + } + + test("LocalStopwatch") { + val sw = new LocalStopwatch("sw") + testStopwatchOnDriver(sw) + } + + test("DistributedStopwatch on driver") { + val sw = new DistributedStopwatch(sc, "sw") + testStopwatchOnDriver(sw) + } + + test("DistributedStopwatch on executors") { + val sw = new DistributedStopwatch(sc, "sw") + val rdd = sc.parallelize(0 until 4, 4) + rdd.foreach { i => + sw.start() + Thread.sleep(50) + sw.stop() + } + assert(!sw.isRunning) + val elapsed = sw.elapsed() + assert(elapsed >= 200 && elapsed < 400) // using a loose upper bound + } + + test("MultiStopwatch") { + val sw = new MultiStopwatch(sc) + .addLocal("local") + .addDistributed("spark") + assert(sw("local").name === "local") + assert(sw("spark").name === "spark") + intercept[NoSuchElementException] { + sw("some") + } + assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}") + sw("local").start() + sw("spark").start() + Thread.sleep(50) + sw("local").stop() + Thread.sleep(50) + sw("spark").stop() + val localElapsed = sw("local").elapsed() + val sparkElapsed = sw("spark").elapsed() + assert(localElapsed >= 50 && localElapsed < 100) + assert(sparkElapsed >= 100 && sparkElapsed < 200) + assert(sw.toString === + s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}") + val rdd = sc.parallelize(0 until 4, 4) + rdd.foreach { i => + sw("local").start() + sw("spark").start() + Thread.sleep(50) + sw("spark").stop() + sw("local").stop() + } + val localElapsed2 = sw("local").elapsed() + assert(localElapsed2 === localElapsed) + val sparkElapsed2 = sw("spark").elapsed() + assert(sparkElapsed2 >= 300 && sparkElapsed2 < 600) + } +} From 9c64a75bfc5e2566d1b4cd0d9b4585a818086ca6 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 15 Jul 2015 21:08:30 -0700 Subject: [PATCH 36/46] [SPARK-9060] [SQL] Revert SPARK-8359, SPARK-8800, and SPARK-8677 JIRA: https://issues.apache.org/jira/browse/SPARK-9060 This PR reverts: * https://github.com/apache/spark/commit/31bd30687bc29c0e457c37308d489ae2b6e5b72a (SPARK-8359) * https://github.com/apache/spark/commit/24fda7381171738cbbbacb5965393b660763e562 (SPARK-8677) * https://github.com/apache/spark/commit/4b5cfc988f23988c2334882a255d494fc93d252e (SPARK-8800) Author: Yin Huai Closes #7426 from yhuai/SPARK-9060 and squashes the following commits: 651264d [Yin Huai] Revert "[SPARK-8359] [SQL] Fix incorrect decimal precision after multiplication" cfda7e4 [Yin Huai] Revert "[SPARK-8677] [SQL] Fix non-terminating decimal expansion for decimal divide operation" 2de9afe [Yin Huai] Revert "[SPARK-8800] [SQL] Fix inaccurate precision/scale of Decimal division operation" --- .../org/apache/spark/sql/types/Decimal.scala | 21 ++----------------- 
.../sql/types/decimal/DecimalSuite.scala | 18 ---------------- 2 files changed, 2 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index f5bd068d60dc4..a85af9e04aedb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.types -import java.math.{MathContext, RoundingMode} - import org.apache.spark.annotation.DeveloperApi /** @@ -138,14 +136,6 @@ final class Decimal extends Ordered[Decimal] with Serializable { } def toBigDecimal: BigDecimal = { - if (decimalVal.ne(null)) { - decimalVal(MathContext.UNLIMITED) - } else { - BigDecimal(longVal, _scale)(MathContext.UNLIMITED) - } - } - - def toLimitedBigDecimal: BigDecimal = { if (decimalVal.ne(null)) { decimalVal } else { @@ -273,15 +263,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) - def / (that: Decimal): Decimal = { - if (that.isZero) { - null - } else { - // To avoid non-terminating decimal expansion problem, we get scala's BigDecimal with limited - // precision and scala. - Decimal(toLimitedBigDecimal / that.toLimitedBigDecimal) - } - } + def / (that: Decimal): Decimal = + if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) def % (that: Decimal): Decimal = if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index f0c849d1a1564..1d297beb3868d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -171,22 +171,4 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L) assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue) } - - test("accurate precision after multiplication") { - val decimal = (Decimal(Long.MaxValue, 38, 0) * Decimal(Long.MaxValue, 38, 0)).toJavaBigDecimal - assert(decimal.unscaledValue.toString === "85070591730234615847396907784232501249") - } - - test("fix non-terminating decimal expansion problem") { - val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3) - // The difference between decimal should not be more than 0.001. 
- assert(decimal.toDouble - 0.333 < 0.001) - } - - test("fix loss of precision/scale when doing division operation") { - val a = Decimal(2) / Decimal(3) - assert(a.toDouble < 1.0 && a.toDouble > 0.6) - val b = Decimal(1) / Decimal(8) - assert(b.toDouble === 0.125) - } } From 42dea3acf90ec506a0b79720b55ae1d753cc7544 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 15 Jul 2015 21:47:21 -0700 Subject: [PATCH 37/46] [SPARK-8245][SQL] FormatNumber/Length Support for Expression - `BinaryType` for `Length` - `FormatNumber` Author: Cheng Hao Closes #7034 from chenghao-intel/expression and squashes the following commits: e534b87 [Cheng Hao] python api style issue 601bbf5 [Cheng Hao] add python API support 3ebe288 [Cheng Hao] update as feedback 52274f7 [Cheng Hao] add support for udf_format_number and length for binary --- python/pyspark/sql/functions.py | 25 ++++- .../catalyst/analysis/FunctionRegistry.scala | 5 +- .../expressions/stringOperations.scala | 94 +++++++++++++++++-- .../expressions/StringFunctionsSuite.scala | 53 +++++++++-- .../org/apache/spark/sql/functions.scala | 32 ++++++- .../spark/sql/DataFrameFunctionsSuite.scala | 93 +++++++++++++++--- 6 files changed, 261 insertions(+), 41 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dca39fa833435..e0816b3e654bc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -39,6 +39,8 @@ 'coalesce', 'countDistinct', 'explode', + 'format_number', + 'length', 'log2', 'md5', 'monotonicallyIncreasingId', @@ -47,7 +49,6 @@ 'sha1', 'sha2', 'sparkPartitionId', - 'strlen', 'struct', 'udf', 'when'] @@ -506,14 +507,28 @@ def sparkPartitionId(): @ignore_unicode_prefix @since(1.5) -def strlen(col): - """Calculates the length of a string expression. +def length(col): + """Calculates the length of a string or binary expression. - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect() + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() [Row(length=3)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.strlen(_to_java_column(col))) + return Column(sc._jvm.functions.length(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(1.5) +def format_number(col, d): + """Formats the number X to a format like '#,###,###.##', rounded to d decimal places, + and returns the result as a string. 
+ :param col: the column name of the numeric value to be formatted + :param d: the N decimal places + >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect() + [Row(v=u'5.0000')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.format_number(_to_java_column(col), d)) @ignore_unicode_prefix diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d2678ce860701..e0beafe710079 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -152,11 +152,12 @@ object FunctionRegistry { expression[Base64]("base64"), expression[Encode]("encode"), expression[Decode]("decode"), - expression[StringInstr]("instr"), + expression[FormatNumber]("format_number"), expression[Lower]("lcase"), expression[Lower]("lower"), - expression[StringLength]("length"), + expression[Length]("length"), expression[Levenshtein]("levenshtein"), + expression[StringInstr]("instr"), expression[StringLocate]("locate"), expression[StringLPad]("lpad"), expression[StringTrimLeft]("ltrim"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 03b55ce5fe7cc..c64afe7b3f19a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import java.text.DecimalFormat import java.util.Locale import java.util.regex.Pattern -import org.apache.commons.lang3.StringUtils - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -553,17 +552,22 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } /** - * A function that return the length of the given string expression. + * A function that return the length of the given string or binary expression. 
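+ * For a string input this is the number of characters; for a binary input it is the number of bytes.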
*/ -case class StringLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType - override def inputTypes: Seq[DataType] = Seq(StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) - protected override def nullSafeEval(string: Any): Any = - string.asInstanceOf[UTF8String].numChars + protected override def nullSafeEval(value: Any): Any = child.dataType match { + case StringType => value.asInstanceOf[UTF8String].numChars + case BinaryType => value.asInstanceOf[Array[Byte]].length + } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"($c).numChars()") + child.dataType match { + case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()") + case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") + } } override def prettyName: String = "length" @@ -668,3 +672,77 @@ case class Encode(value: Expression, charset: Expression) } } +/** + * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, + * and returns the result as a string. If D is 0, the result has no decimal point or + * fractional part. + */ +case class FormatNumber(x: Expression, d: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = x + override def right: Expression = d + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + + // Associated with the pattern, for the last d value, and we will update the + // pattern (DecimalFormat) once the new coming d value differ with the last one. + @transient + private var lastDValue: Int = -100 + + // A cached DecimalFormat, for performance concern, we will change it + // only if the d value changed. 
+ @transient + private val pattern: StringBuffer = new StringBuffer() + + @transient + private val numberFormat: DecimalFormat = new DecimalFormat("") + + override def eval(input: InternalRow): Any = { + val xObject = x.eval(input) + if (xObject == null) { + return null + } + + val dObject = d.eval(input) + + if (dObject == null || dObject.asInstanceOf[Int] < 0) { + return null + } + val dValue = dObject.asInstanceOf[Int] + + if (dValue != lastDValue) { + // construct a new DecimalFormat only if a new dValue + pattern.delete(0, pattern.length()) + pattern.append("#,###,###,###,###,###,##0") + + // decimal place + if (dValue > 0) { + pattern.append(".") + + var i = 0 + while (i < dValue) { + i += 1 + pattern.append("0") + } + } + val dFormat = new DecimalFormat(pattern.toString()) + lastDValue = dValue; + numberFormat.applyPattern(dFormat.toPattern()) + } + + x.dataType match { + case ByteType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Byte])) + case ShortType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Short])) + case FloatType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Float])) + case IntegerType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Int])) + case LongType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Long])) + case DoubleType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Double])) + case _: DecimalType => + UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Decimal].toJavaBigDecimal)) + } + } + + override def prettyName: String = "format_number" +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index b19f4ee37a109..5d7763bedf6bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType} +import org.apache.spark.sql.types._ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -216,15 +216,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("length for string") { - val a = 'a.string.at(0) - checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef")) - checkEvaluation(StringLength(a), 5, create_row("abdef")) - checkEvaluation(StringLength(a), 0, create_row("")) - checkEvaluation(StringLength(a), null, create_row(null)) - checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef")) - } - test("ascii for string") { val a = 'a.string.at(0) checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef")) @@ -426,4 +417,46 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) } + + test("length for string / binary") { + val a = 'a.string.at(0) + val b = 'b.binary.at(0) + val bytes = Array[Byte](1, 2, 3, 1, 2) + val string = "abdef" + + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. 
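+    // "a花花c" has 4 characters (though more than 4 bytes once encoded), so the string length is 4.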
+ checkEvaluation(Length(Literal("a花花c")), 4, create_row(string)) + // scalastyle:on + checkEvaluation(Length(Literal(bytes)), 5, create_row(Array[Byte]())) + + checkEvaluation(Length(a), 5, create_row(string)) + checkEvaluation(Length(b), 5, create_row(bytes)) + + checkEvaluation(Length(a), 0, create_row("")) + checkEvaluation(Length(b), 0, create_row(Array[Byte]())) + + checkEvaluation(Length(a), null, create_row(null)) + checkEvaluation(Length(b), null, create_row(null)) + + checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(string)) + checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes)) + } + + test("number format") { + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4.0f), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(12831273.23481d), Literal(3)), "12,831,273.235") + checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal(0)), "12,831,274") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal(3)), "123,123,324,123.000") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal(-1)), null) + checkEvaluation( + FormatNumber( + Literal(Decimal(123123324123L) * Decimal(123123.21234d)), Literal(4)), + "15,159,339,180,002,773.2778") + checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null) + checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c7deaca8437a1..d6da284a4c788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1685,20 +1685,44 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Computes the length of a given string value. + * Computes the length of a given string / binary value. * * @group string_funcs * @since 1.5.0 */ - def strlen(e: Column): Column = StringLength(e.expr) + def length(e: Column): Column = Length(e.expr) /** - * Computes the length of a given string column. + * Computes the length of a given string / binary column. * * @group string_funcs * @since 1.5.0 */ - def strlen(columnName: String): Column = strlen(Column(columnName)) + def length(columnName: String): Column = length(Column(columnName)) + + /** + * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, + * and returns the result as a string. + * If d is 0, the result has no decimal point or fractional part. + * If d < 0, the result will be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr) + + /** + * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, + * and returns the result as a string. + * If d is 0, the result has no decimal point or fractional part. + * If d < 0, the result will be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def format_number(columnXName: String, d: Int): Column = { + format_number(Column(columnXName), d) + } /** * Computes the Levenshtein distance of the two given strings. 
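A minimal usage sketch of the two functions added to `org.apache.spark.sql.functions` above (an illustration only: it assumes an existing SQLContext named `sqlContext`, and the DataFrame and column names are made up):

    import org.apache.spark.sql.functions.{format_number, length}
    import sqlContext.implicits._  // assumes `sqlContext` is an existing SQLContext

    val df = Seq(("abc", 1234.5678), ("de", 2.0)).toDF("s", "x")

    // length() now accepts string (and binary) columns; format_number() rounds
    // to the requested number of decimal places and adds grouping separators,
    // e.g. 1234.5678 formatted with d = 2 becomes "1,234.57".
    df.select(length($"s"), format_number($"x", 2)).show()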
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 70bd78737f69c..6dccdd857b453 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -208,17 +208,6 @@ class DataFrameFunctionsSuite extends QueryTest { Row(2743272264L, 2180413220L)) } - test("string length function") { - val df = Seq(("abc", "")).toDF("a", "b") - checkAnswer( - df.select(strlen($"a"), strlen("b")), - Row(3, 0)) - - checkAnswer( - df.selectExpr("length(a)", "length(b)"), - Row(3, 0)) - } - test("Levenshtein distance") { val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1))) @@ -433,11 +422,91 @@ class DataFrameFunctionsSuite extends QueryTest { val doubleData = Seq((7.2, 4.1)).toDF("a", "b") checkAnswer( doubleData.select(pmod('a, 'b)), - Seq(Row(3.1000000000000005)) // same as hive + Seq(Row(3.1000000000000005)) // same as hive ) checkAnswer( doubleData.select(pmod(lit(2), lit(Int.MaxValue))), Seq(Row(2)) ) } + + test("string / binary length function") { + val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c") + checkAnswer( + df.select(length($"a"), length("a"), length($"b"), length("b")), + Row(3, 3, 4, 4)) + + checkAnswer( + df.selectExpr("length(a)", "length(b)"), + Row(3, 4)) + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("length(c)"), // int type of the argument is unacceptable + Row("5.0000")) + } + } + + test("number format function") { + val tuple = + ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short], + 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381)) + val df = + Seq(tuple) + .toDF( + "a", // string "aa" + "b", // byte 1 + "c", // short 2 + "d", // float 3.13223f + "e", // integer 4 + "f", // long 5L + "g", // double 6.48173d + "h") // decimal 7.128381 + + checkAnswer( + df.select( + format_number($"f", 4), + format_number("f", 4)), + Row("5.0000", "5.0000")) + + checkAnswer( + df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer + Row("1.0000")) + + checkAnswer( + df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer + Row("2.0000")) + + checkAnswer( + df.selectExpr("format_number(d, e)"), // convert the 1st argument to double + Row("3.1322")) + + checkAnswer( + df.selectExpr("format_number(e, e)"), // not convert anything + Row("4.0000")) + + checkAnswer( + df.selectExpr("format_number(f, e)"), // not convert anything + Row("5.0000")) + + checkAnswer( + df.selectExpr("format_number(g, e)"), // not convert anything + Row("6.4817")) + + checkAnswer( + df.selectExpr("format_number(h, e)"), // not convert anything + Row("7.1284")) + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable + Row("5.0000")) + } + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable + Row("5.0000")) + } + } } From ba33096846dc8061e97a7bf8f3b46f899d530159 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 15 Jul 2015 22:27:39 -0700 Subject: [PATCH 38/46] [SPARK-9068][SQL] refactor the implicit type cast code based on https://github.com/apache/spark/pull/7348 Author: Wenchen Fan Closes #7420 from cloud-fan/type-check and squashes the following commits: 7633fa9 [Wenchen Fan] 
revert fe169b0 [Wenchen Fan] improve test 03b70da [Wenchen Fan] enhance implicit type cast --- .../catalyst/analysis/HiveTypeCoercion.scala | 33 +++----- .../sql/catalyst/expressions/Expression.scala | 20 +++-- .../sql/catalyst/expressions/arithmetic.scala | 2 - .../sql/catalyst/expressions/bitwise.scala | 8 +- .../catalyst/expressions/conditionals.scala | 4 +- .../spark/sql/types/AbstractDataType.scala | 45 +++-------- .../apache/spark/sql/types/ArrayType.scala | 2 +- .../org/apache/spark/sql/types/DataType.scala | 2 +- .../apache/spark/sql/types/DecimalType.scala | 2 +- .../org/apache/spark/sql/types/MapType.scala | 2 +- .../apache/spark/sql/types/StructType.scala | 2 +- .../ExpressionTypeCheckingSuite.scala | 75 +++++++++---------- .../analysis/HiveTypeCoercionSuite.scala | 10 +-- 13 files changed, 81 insertions(+), 126 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 25087915b5c35..50db7d21f01ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -675,10 +675,10 @@ object HiveTypeCoercion { case b @ BinaryOperator(left, right) if left.dataType != right.dataType => findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType => if (b.inputType.acceptsType(commonType)) { - // If the expression accepts the tighest common type, cast to that. + // If the expression accepts the tightest common type, cast to that. val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) val newRight = if (right.dataType == commonType) right else Cast(right, commonType) - b.makeCopy(Array(newLeft, newRight)) + b.withNewChildren(Seq(newLeft, newRight)) } else { // Otherwise, don't do anything with the expression. b @@ -697,7 +697,7 @@ object HiveTypeCoercion { // general implicit casting. val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => if (in.dataType == NullType && !expected.acceptsType(NullType)) { - Cast(in, expected.defaultConcreteType) + Literal.create(null, expected.defaultConcreteType) } else { in } @@ -719,27 +719,22 @@ object HiveTypeCoercion { @Nullable val ret: Expression = (inType, expectedType) match { // If the expected type is already a parent of the input type, no need to cast. - case _ if expectedType.isSameType(inType) => e + case _ if expectedType.acceptsType(inType) => e // Cast null type (usually from null literals) into target types case (NullType, target) => Cast(e, target.defaultConcreteType) - // If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is - // already a number, leave it as is. - case (_: NumericType, NumericType) => e - // If the function accepts any numeric type and the input is a string, we follow the hive // convention and cast that input into a double case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType) - // Implicit cast among numeric types + // Implicit cast among numeric types. When we reach here, input type is not acceptable. + // If input is a numeric type but not decimal, and we expect a decimal type, // cast the input to unlimited precision decimal. 
- case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] => - Cast(e, DecimalType.Unlimited) + case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited) // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long - case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target) - case (_: NumericType, target: NumericType) => e + case (_: NumericType, target: NumericType) => Cast(e, target) // Implicit cast between date time types case (DateType, TimestampType) => Cast(e, TimestampType) @@ -753,15 +748,9 @@ object HiveTypeCoercion { case (StringType, BinaryType) => Cast(e, BinaryType) case (any, StringType) if any != StringType => Cast(e, StringType) - // Type collection. - // First see if we can find our input type in the type collection. If we can, then just - // use the current expression; otherwise, find the first one we can implicitly cast. - case (_, TypeCollection(types)) => - if (types.exists(_.isSameType(inType))) { - e - } else { - types.flatMap(implicitCast(e, _)).headOption.orNull - } + // When we reach here, input type is not acceptable for any types in this type collection, + // try to find the first one we can implicitly cast. + case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull // Else, just return the same input expression case _ => null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 87667316aca67..a655cc8e48ae1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -386,17 +386,15 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType) override def checkInputDataTypes(): TypeCheckResult = { - // First call the checker for ExpectsInputTypes, and then check whether left and right have - // the same type. - super.checkInputDataTypes() match { - case TypeCheckResult.TypeCheckSuccess => - if (left.dataType != right.dataType) { - TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + - s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") - } else { - TypeCheckResult.TypeCheckSuccess - } - case TypeCheckResult.TypeCheckFailure(msg) => TypeCheckResult.TypeCheckFailure(msg) + // First check whether left and right have the same type, then check if the type is acceptable. 
+ if (left.dataType != right.dataType) { + TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") + } else if (!inputType.acceptsType(left.dataType)) { + TypeCheckResult.TypeCheckFailure(s"'$prettyString' accepts ${inputType.simpleString} type," + + s" not ${left.dataType.simpleString}") + } else { + TypeCheckResult.TypeCheckSuccess } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 394ef556e04a2..382cbe3b84a07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -320,7 +320,6 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } override def symbol: String = "max" - override def prettyName: String = symbol } case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { @@ -375,7 +374,6 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } override def symbol: String = "min" - override def prettyName: String = symbol } case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala index af1abbcd2239b..a1e48c4210877 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._ */ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - override def inputType: AbstractDataType = TypeCollection.Bitwise + override def inputType: AbstractDataType = IntegralType override def symbol: String = "&" @@ -53,7 +53,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme */ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - override def inputType: AbstractDataType = TypeCollection.Bitwise + override def inputType: AbstractDataType = IntegralType override def symbol: String = "|" @@ -78,7 +78,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - override def inputType: AbstractDataType = TypeCollection.Bitwise + override def inputType: AbstractDataType = IntegralType override def symbol: String = "^" @@ -101,7 +101,7 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme */ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.Bitwise) + override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index c7f039ede26b3..9162b73fe56eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -35,8 +35,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi TypeCheckResult.TypeCheckFailure( s"type of predicate expression in If should be boolean, not ${predicate.dataType}") } else if (trueValue.dataType != falseValue.dataType) { - TypeCheckResult.TypeCheckFailure( - s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).") + TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") } else { TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index f5715f7a829ff..076d7b5a5118d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -34,32 +34,18 @@ private[sql] abstract class AbstractDataType { private[sql] def defaultConcreteType: DataType /** - * Returns true if this data type is the same type as `other`. This is different that equality - * as equality will also consider data type parametrization, such as decimal precision. + * Returns true if `other` is an acceptable input type for a function that expects this, + * possibly abstract DataType. * * {{{ * // this should return true - * DecimalType.isSameType(DecimalType(10, 2)) - * - * // this should return false - * NumericType.isSameType(DecimalType(10, 2)) - * }}} - */ - private[sql] def isSameType(other: DataType): Boolean - - /** - * Returns true if `other` is an acceptable input type for a function that expectes this, - * possibly abstract, DataType. - * - * {{{ - * // this should return true - * DecimalType.isSameType(DecimalType(10, 2)) + * DecimalType.acceptsType(DecimalType(10, 2)) * * // this should return true as well * NumericType.acceptsType(DecimalType(10, 2)) * }}} */ - private[sql] def acceptsType(other: DataType): Boolean = isSameType(other) + private[sql] def acceptsType(other: DataType): Boolean /** Readable string representation for the type. */ private[sql] def simpleString: String @@ -83,10 +69,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType]) override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType - override private[sql] def isSameType(other: DataType): Boolean = false - override private[sql] def acceptsType(other: DataType): Boolean = - types.exists(_.isSameType(other)) + types.exists(_.acceptsType(other)) override private[sql] def simpleString: String = { types.map(_.simpleString).mkString("(", " or ", ")") @@ -107,13 +91,6 @@ private[sql] object TypeCollection { TimestampType, DateType, StringType, BinaryType) - /** - * Types that can be used in bitwise operations. 
- */ - val Bitwise = TypeCollection( - BooleanType, - ByteType, ShortType, IntegerType, LongType) - def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types) def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match { @@ -134,8 +111,6 @@ protected[sql] object AnyDataType extends AbstractDataType { override private[sql] def simpleString: String = "any" - override private[sql] def isSameType(other: DataType): Boolean = false - override private[sql] def acceptsType(other: DataType): Boolean = true } @@ -183,13 +158,11 @@ private[sql] object NumericType extends AbstractDataType { override private[sql] def simpleString: String = "numeric" - override private[sql] def isSameType(other: DataType): Boolean = false - override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType] } -private[sql] object IntegralType { +private[sql] object IntegralType extends AbstractDataType { /** * Enables matching against IntegralType for expressions: * {{{ @@ -198,6 +171,12 @@ private[sql] object IntegralType { * }}} */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType] + + override private[sql] def defaultConcreteType: DataType = IntegerType + + override private[sql] def simpleString: String = "integral" + + override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[IntegralType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 76ca7a84c1d1a..5094058164b2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -28,7 +28,7 @@ object ArrayType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true) - override private[sql] def isSameType(other: DataType): Boolean = { + override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[ArrayType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index da83a7f0ba379..2d133eea19fe0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = this - override private[sql] def isSameType(other: DataType): Boolean = this == other + override private[sql] def acceptsType(other: DataType): Boolean = this == other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index a1cafeab1704d..377c75f6e85a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -86,7 +86,7 @@ object DecimalType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = Unlimited - override private[sql] def isSameType(other: DataType): Boolean = { + override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[DecimalType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index ddead10bc2171..ac34b642827ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -71,7 +71,7 @@ object MapType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType) - override private[sql] def isSameType(other: DataType): Boolean = { + override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[MapType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index b8097403ec3cc..2ef97a427c37e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -307,7 +307,7 @@ object StructType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = new StructType - override private[sql] def isSameType(other: DataType): Boolean = { + override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[StructType] } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index a4ce1825cab28..ed0d20e7de80e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{TypeCollection, StringType} class ExpressionTypeCheckingSuite extends SparkFunSuite { @@ -49,23 +49,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { def assertErrorForDifferingTypes(expr: Expression): Unit = { assertError(expr, - s"differing types in '${expr.prettyString}' (int and boolean)") - } - - def assertErrorWithImplicitCast(expr: Expression, errorMessage: String): Unit = { - val e = intercept[AnalysisException] { - assertSuccess(expr) - } - assert(e.getMessage.contains(errorMessage)) + s"differing types in '${expr.prettyString}'") } test("check types for unary arithmetic") { assertError(UnaryMinus('stringField), "expected to be of type numeric") assertError(Abs('stringField), "expected to be of type numeric") - assertError(BitwiseNot('stringField), "type (boolean or tinyint or smallint or int or bigint)") + assertError(BitwiseNot('stringField), "expected to be of type integral") } - ignore("check types for binary arithmetic") { + test("check types for binary arithmetic") { // We will cast String to Double for binary arithmetic assertSuccess(Add('intField, 'stringField)) assertSuccess(Subtract('intField, 'stringField)) @@ -85,21 +78,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) - assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type") - assertError(Subtract('booleanField, 'booleanField), "operator - accepts 
numeric type") - assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type") - assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type") - assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type") + assertError(Add('booleanField, 'booleanField), "accepts numeric type") + assertError(Subtract('booleanField, 'booleanField), "accepts numeric type") + assertError(Multiply('booleanField, 'booleanField), "accepts numeric type") + assertError(Divide('booleanField, 'booleanField), "accepts numeric type") + assertError(Remainder('booleanField, 'booleanField), "accepts numeric type") - assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type") - assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type") - assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type") + assertError(BitwiseAnd('booleanField, 'booleanField), "accepts integral type") + assertError(BitwiseOr('booleanField, 'booleanField), "accepts integral type") + assertError(BitwiseXor('booleanField, 'booleanField), "accepts integral type") - assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type") - assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type") + assertError(MaxOf('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") + assertError(MinOf('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") } - ignore("check types for predicates") { + test("check types for predicates") { // We will cast String to Double for binary comparison assertSuccess(EqualTo('intField, 'stringField)) assertSuccess(EqualNullSafe('intField, 'stringField)) @@ -112,25 +107,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(EqualTo('intField, 'booleanField)) assertSuccess(EqualNullSafe('intField, 'booleanField)) - assertError(EqualTo('intField, 'complexField), "differing types") - assertError(EqualNullSafe('intField, 'complexField), "differing types") - + assertErrorForDifferingTypes(EqualTo('intField, 'complexField)) + assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField)) assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError( - LessThan('complexField, 'complexField), "operator < accepts non-complex type") - assertError( - LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type") - assertError( - GreaterThan('complexField, 'complexField), "operator > accepts non-complex type") - assertError( - GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type") + assertError(LessThan('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") + assertError(LessThanOrEqual('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") + assertError(GreaterThan('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") + assertError(GreaterThanOrEqual('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") - assertError( - If('intField, 'stringField, 'stringField), + assertError(If('intField, 'stringField, 
'stringField), "type of predicate expression in If should be boolean") assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) @@ -180,12 +173,12 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for ROUND") { - assertErrorWithImplicitCast(Round(Literal(null), 'booleanField), - "data type mismatch: argument 2 is expected to be of type int") - assertErrorWithImplicitCast(Round(Literal(null), 'complexField), - "data type mismatch: argument 2 is expected to be of type int") assertSuccess(Round(Literal(null), Literal(null))) - assertError(Round('booleanField, 'intField), - "data type mismatch: argument 1 is expected to be of type numeric") + assertSuccess(Round('intField, Literal(1))) + + assertError(Round('intField, 'intField), "Only foldable Expression is allowed") + assertError(Round('intField, 'booleanField), "expected to be of type int") + assertError(Round('intField, 'complexField), "expected to be of type int") + assertError(Round('booleanField, 'intField), "expected to be of type numeric") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 8e9b20a3ebe42..d0fd033b981c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -203,7 +203,7 @@ class HiveTypeCoercionSuite extends PlanTest { ruleTest(HiveTypeCoercion.ImplicitTypeCasts, NumericTypeUnaryExpression(Literal.create(null, NullType)), - NumericTypeUnaryExpression(Cast(Literal.create(null, NullType), DoubleType))) + NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } test("cast NullType for binary operators") { @@ -215,9 +215,7 @@ class HiveTypeCoercionSuite extends PlanTest { ruleTest(HiveTypeCoercion.ImplicitTypeCasts, NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), - NumericTypeBinaryOperator( - Cast(Literal.create(null, NullType), DoubleType), - Cast(Literal.create(null, NullType), DoubleType))) + NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } test("coalesce casts") { @@ -345,14 +343,14 @@ object HiveTypeCoercionSuite { } case class AnyTypeBinaryOperator(left: Expression, right: Expression) - extends BinaryOperator with ExpectsInputTypes { + extends BinaryOperator { override def dataType: DataType = NullType override def inputType: AbstractDataType = AnyDataType override def symbol: String = "anytype" } case class NumericTypeBinaryOperator(left: Expression, right: Expression) - extends BinaryOperator with ExpectsInputTypes { + extends BinaryOperator { override def dataType: DataType = NullType override def inputType: AbstractDataType = NumericType override def symbol: String = "numerictype" From e27212317c7341852c52d9a85137b8f94cb0d935 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 15 Jul 2015 23:35:27 -0700 Subject: [PATCH 39/46] [SPARK-8972] [SQL] Incorrect result for rollup We don't support the complex expression keys in the rollup/cube, and we even will not report it if we have the complex group by keys, that will cause very confusing/incorrect result. e.g. 
`SELECT key%100 FROM src GROUP BY key %100 with ROLLUP` This PR adds an additional project during the analyzing for the complex GROUP BY keys, and that projection will be the child of `Expand`, so to `Expand`, the GROUP BY KEY are always the simple key(attribute names). Author: Cheng Hao Closes #7343 from chenghao-intel/expand and squashes the following commits: 1ebbb59 [Cheng Hao] update the comment 827873f [Cheng Hao] update as feedback 34def69 [Cheng Hao] Add more unit test and comments c695760 [Cheng Hao] fix bug of incorrect result for rollup --- .../sql/catalyst/analysis/Analyzer.scala | 42 +++++++++++++-- ...CUBE #1-0-63b61fb3f0e74226001ad279be440864 | 6 +++ ...CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 | 10 ++++ ...pingSet-0-8c14c24670a4b06c440346277ce9cf1c | 10 ++++ ...llup #1-0-a78e3dbf242f240249e36b3d3fd0926a | 6 +++ ...llup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 | 10 ++++ ...llup #3-0-9257085d123728730be96b6d9fbb84ce | 10 ++++ .../sql/hive/execution/HiveQuerySuite.scala | 54 +++++++++++++++++++ 8 files changed, 145 insertions(+), 3 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 create mode 100644 sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 create mode 100644 sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c create mode 100644 sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a create mode 100644 sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 create mode 100644 sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 891408e310049..df8e7f2381fbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -194,16 +194,52 @@ class Analyzer( } def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a if !a.childrenResolved => a // be sure all of the children are resolved. case a: Cube => GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) case a: Rollup => GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) case x: GroupingSets => val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() + // We will insert another Projection if the GROUP BY keys contains the + // non-attribute expressions. And the top operators can references those + // expressions by its alias. + // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==> + // SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a + + // find all of the non-attribute expressions in the GROUP BY keys + val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]() + + // The pair of (the original GROUP BY key, associated attribute) + val groupByExprPairs = x.groupByExprs.map(_ match { + case e: NamedExpression => (e, e.toAttribute) + case other => { + val alias = Alias(other, other.toString)() + nonAttributeGroupByExpressions += alias // add the non-attributes expression alias + (other, alias.toAttribute) + } + }) + + // substitute the non-attribute expressions for aggregations. 
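+      // e.g. the "key % 5" appearing in the aggregation list is replaced with the attribute
+      // generated for the matching GROUP BY key, so it resolves against the extra Project's output.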
+ val aggregation = x.aggregations.map(expr => expr.transformDown { + case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e) + }.asInstanceOf[NamedExpression]) + + // substitute the group by expressions. + val newGroupByExprs = groupByExprPairs.map(_._2) + + val child = if (nonAttributeGroupByExpressions.length > 0) { + // insert additional projection if contains the + // non-attribute expressions in the GROUP BY keys + Project(x.child.output ++ nonAttributeGroupByExpressions, x.child) + } else { + x.child + } + Aggregate( - x.groupByExprs :+ VirtualColumn.groupingIdAttribute, - x.aggregations, - Expand(x.bitmasks, x.groupByExprs, gid, x.child)) + newGroupByExprs :+ VirtualColumn.groupingIdAttribute, + aggregation, + Expand(x.bitmasks, newGroupByExprs, gid, child)) } } diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 new file mode 100644 index 0000000000000..dac1b84b916d7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 @@ -0,0 +1,6 @@ +500 NULL 0 +91 0 1 +84 1 1 +105 2 1 +113 3 1 +107 4 1 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 new file mode 100644 index 0000000000000..c7cb747c0a659 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 @@ -0,0 +1,10 @@ +1 NULL -3 2 +1 NULL -1 2 +1 NULL 3 2 +1 NULL 4 2 +1 NULL 5 2 +1 NULL 6 2 +1 NULL 12 2 +1 NULL 14 2 +1 NULL 15 2 +1 NULL 22 2 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c new file mode 100644 index 0000000000000..c7cb747c0a659 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c @@ -0,0 +1,10 @@ +1 NULL -3 2 +1 NULL -1 2 +1 NULL 3 2 +1 NULL 4 2 +1 NULL 5 2 +1 NULL 6 2 +1 NULL 12 2 +1 NULL 14 2 +1 NULL 15 2 +1 NULL 22 2 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a new file mode 100644 index 0000000000000..dac1b84b916d7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a @@ -0,0 +1,6 @@ +500 NULL 0 +91 0 1 +84 1 1 +105 2 1 +113 3 1 +107 4 1 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 new file mode 100644 index 0000000000000..1eea4a9b23687 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 @@ -0,0 +1,10 @@ +1 0 5 3 +1 0 15 3 +1 0 25 3 +1 0 60 3 +1 0 75 3 +1 0 80 3 +1 0 100 3 +1 0 140 3 +1 0 145 3 +1 0 150 3 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce 
b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce new file mode 100644 index 0000000000000..1eea4a9b23687 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce @@ -0,0 +1,10 @@ +1 0 5 3 +1 0 15 3 +1 0 25 3 +1 0 60 3 +1 0 75 3 +1 0 80 3 +1 0 100 3 +1 0 140 3 +1 0 145 3 +1 0 150 3 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 991da2f829ae5..11a843becce69 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -85,6 +85,60 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + createQueryTest("SPARK-8976 Wrong Result for Rollup #1", + """ + SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH ROLLUP + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for Rollup #2", + """ + SELECT + count(*) AS cnt, + key % 5 as k1, + key-5 as k2, + GROUPING__ID as k3 + FROM src group by key%5, key-5 + WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for Rollup #3", + """ + SELECT + count(*) AS cnt, + key % 5 as k1, + key-5 as k2, + GROUPING__ID as k3 + FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 + WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for CUBE #1", + """ + SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH CUBE + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for CUBE #2", + """ + SELECT + count(*) AS cnt, + key % 5 as k1, + key-5 as k2, + GROUPING__ID as k3 + FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 + WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin) + + createQueryTest("SPARK-8976 Wrong Result for GroupingSet", + """ + SELECT + count(*) AS cnt, + key % 5 as k1, + key-5 as k2, + GROUPING__ID as k3 + FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 + GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin) + createQueryTest("insert table with generator with column name", """ | CREATE TABLE gen_tmp (key Int); From 0a795336df20c7ec969366e613286f0c060a4eeb Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 15 Jul 2015 23:36:57 -0700 Subject: [PATCH 40/46] [SPARK-8807] [SPARKR] Add between operator in SparkR JIRA: https://issues.apache.org/jira/browse/SPARK-8807 Add between operator in SparkR. Author: Liang-Chi Hsieh Closes #7356 from viirya/add_r_between and squashes the following commits: 7f51b44 [Liang-Chi Hsieh] Add test for non-numeric column. c6a25c5 [Liang-Chi Hsieh] Add between function. 
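The new R `between` generic is a thin wrapper: it checks that `bounds` is a two-element vector and then delegates to the JVM `Column.between` method via `callJMethod`, so the bounds are inclusive on both ends. For readers more familiar with the Scala API, here is a minimal, hedged sketch of the same inclusive-range check written directly against `Column.between`; the `sqlContext`, the sample data, and the object name are illustrative assumptions, not part of this patch.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Hedged sketch: shows the JVM Column.between call that the new SparkR wrapper delegates to.
object BetweenExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("between-example").setMaster("local[*]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Illustrative data, analogous to the people JSON used by the R test.
    val df = sc.parallelize(Seq(("Michael", 29), ("Andy", 30), ("Justin", 19))).toDF("name", "age")

    // Column.between(lower, upper) is inclusive on both bounds, matching between(df$age, c(20, 30)).
    df.select(df("name"), df("age").between(20, 30)).show()

    sc.stop()
  }
}
```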
--- R/pkg/NAMESPACE | 1 + R/pkg/R/column.R | 17 +++++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/inst/tests/test_sparkSQL.R | 12 ++++++++++++ 4 files changed, 34 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7f857222452d4..331307c2077a5 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -77,6 +77,7 @@ exportMethods("abs", "atan", "atan2", "avg", + "between", "cast", "cbrt", "ceiling", diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 8e4b0f5bf1c4d..2892e1416cc65 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -187,6 +187,23 @@ setMethod("substr", signature(x = "Column"), column(jc) }) +#' between +#' +#' Test if the column is between the lower bound and upper bound, inclusive. +#' +#' @rdname column +#' +#' @param bounds lower and upper bounds +setMethod("between", signature(x = "Column"), + function(x, bounds) { + if (is.vector(bounds) && length(bounds) == 2) { + jc <- callJMethod(x@jc, "between", bounds[1], bounds[2]) + column(jc) + } else { + stop("bounds should be a vector of lower and upper bounds") + } + }) + #' Casts the column to a different data type. #' #' @rdname column diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index fad9d71158c51..ebe6fbd97ce86 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -567,6 +567,10 @@ setGeneric("asc", function(x) { standardGeneric("asc") }) #' @export setGeneric("avg", function(x, ...) { standardGeneric("avg") }) +#' @rdname column +#' @export +setGeneric("between", function(x, bounds) { standardGeneric("between") }) + #' @rdname column #' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 76f74f80834a9..cdfe6481f60ea 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -638,6 +638,18 @@ test_that("column functions", { c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) c9 <- toDegrees(c) + toRadians(c) + + df <- jsonFile(sqlContext, jsonPath) + df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) + expect_equal(collect(df2)[[2, 1]], TRUE) + expect_equal(collect(df2)[[2, 2]], FALSE) + expect_equal(collect(df2)[[3, 1]], FALSE) + expect_equal(collect(df2)[[3, 2]], TRUE) + + df3 <- select(df, between(df$name, c("Apache", "Spark"))) + expect_equal(collect(df3)[[1, 1]], TRUE) + expect_equal(collect(df3)[[2, 1]], FALSE) + expect_equal(collect(df3)[[3, 1]], TRUE) }) test_that("column binary mathfunctions", { From 011551620faa87107a787530f074af3d9be7e695 Mon Sep 17 00:00:00 2001 From: Daniel Darabos Date: Thu, 16 Jul 2015 08:16:54 +0100 Subject: [PATCH 41/46] [SPARK-8893] Add runtime checks against non-positive number of partitions https://issues.apache.org/jira/browse/SPARK-8893 > What does `sc.parallelize(1 to 3).repartition(p).collect` return? I would expect `Array(1, 2, 3)` regardless of `p`. But if `p` < 1, it returns `Array()`. I think instead it should throw an `IllegalArgumentException`. > I think the case is pretty clear for `p` < 0. But the behavior for `p` = 0 is also error prone. In fact that's how I found this strange behavior. I used `rdd.repartition(a/b)` with positive `a` and `b`, but `a/b` was rounded down to zero and the results surprised me. I'd prefer an exception instead of unexpected (corrupt) results. 
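The fix itself boils down to a pair of `require(...)` preconditions (one on `HashPartitioner`, one on `CoalescedRDD`, visible in the diff below), so a bad partition count now fails fast with an `IllegalArgumentException` instead of quietly producing an empty result. The sketch below illustrates that fail-fast pattern with a made-up `SimplePartitioner`; it is a hedged, self-contained stand-in, not Spark's `HashPartitioner`.

```scala
// Hedged sketch of the fail-fast validation pattern adopted by this patch.
// SimplePartitioner is an illustrative stand-in, not Spark code.
object PartitionCheckExample {
  class SimplePartitioner(val numPartitions: Int) {
    // Reject counts that could only produce corrupt or empty results.
    require(numPartitions >= 0, s"Number of partitions ($numPartitions) cannot be negative.")

    def getPartition(key: Any): Int = {
      if (numPartitions == 0) 0
      else {
        val h = if (key == null) 0 else key.hashCode()
        ((h % numPartitions) + numPartitions) % numPartitions  // non-negative modulo
      }
    }
  }

  def main(args: Array[String]): Unit = {
    println(new SimplePartitioner(4).getPartition("spark"))  // some id in [0, 4)
    try {
      new SimplePartitioner(-1)  // fails fast: require throws IllegalArgumentException
    } catch {
      case e: IllegalArgumentException => println(s"rejected: ${e.getMessage}")
    }
  }
}
```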
Author: Daniel Darabos Closes #7285 from darabos/patch-1 and squashes the following commits: decba82 [Daniel Darabos] Allow repartitioning empty RDDs to zero partitions. 97de852 [Daniel Darabos] Allow zero partition count in HashPartitioner f6ba5fb [Daniel Darabos] Use require() for simpler syntax. d5e3df8 [Daniel Darabos] Require positive number of partitions in HashPartitioner 897c628 [Daniel Darabos] Require positive maxPartitions in CoalescedRDD --- core/src/main/scala/org/apache/spark/Partitioner.scala | 2 ++ core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 82889bcd30988..ad68512dccb79 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -76,6 +76,8 @@ object Partitioner { * produce an unexpected or incorrect result. */ class HashPartitioner(partitions: Int) extends Partitioner { + require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.") + def numPartitions: Int = partitions def getPartition(key: Any): Int = key match { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 663eebb8e4191..90d9735cb3f69 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -69,7 +69,7 @@ private[spark] case class CoalescedRDDPartition( * the preferred location of each new partition overlaps with as many preferred locations of its * parent partitions * @param prev RDD to be coalesced - * @param maxPartitions number of desired partitions in the coalesced RDD + * @param maxPartitions number of desired partitions in the coalesced RDD (must be positive) * @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance */ private[spark] class CoalescedRDD[T: ClassTag]( @@ -78,6 +78,9 @@ private[spark] class CoalescedRDD[T: ClassTag]( balanceSlack: Double = 0.10) extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies + require(maxPartitions > 0 || maxPartitions == prev.partitions.length, + s"Number of partitions ($maxPartitions) must be positive.") + override def getPartitions: Array[Partition] = { val pc = new PartitionCoalescer(maxPartitions, prev, balanceSlack) From 4ea6480a3ba4ca7e09089c9b99d4a855894b9015 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Thu, 16 Jul 2015 08:26:39 -0700 Subject: [PATCH 42/46] [SPARK-8995] [SQL] cast date strings like '2015-01-01 12:15:31' to date Jira https://issues.apache.org/jira/browse/SPARK-8995 In PR #6981 we noticed that we cannot cast date strings that contain a time, like '2015-03-18 12:39:40', to date. Besides, it's not possible to cast a string like '18:03:20' to a timestamp. If a time is passed without a date, the current date is inferred.
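Internally, the patch replaces the `Timestamp.valueOf`/`Date.valueOf` based parsing with hand-rolled scanners (`DateTimeUtils.stringToTimestamp` and `stringToDate`, shown in the diff below) that walk the string once, accumulate numeric segments, and range-check them. As a hedged, self-contained illustration of that scanning approach for the date part only, the `parseDatePart` helper below is an assumption for explanatory purposes, not the patch's actual code.

```scala
// Hedged, simplified sketch of the stringToDate scanning strategy:
// accumulate year/month/day segments, stop at the first ' ' or 'T'
// (so a trailing time portion is ignored), then range-check the result.
object DateParseSketch {
  def parseDatePart(s: String): Option[(Int, Int, Int)] = {
    val segments = Array(1, 1, 1)   // year, month, day defaults
    var i = 0                       // index of the segment being filled
    var current = 0
    var j = 0
    while (j < s.length && i < 3 && s.charAt(j) != ' ' && s.charAt(j) != 'T') {
      val c = s.charAt(j)
      if (i < 2 && c == '-') {                 // segment separator
        segments(i) = current; current = 0; i += 1
      } else if (c >= '0' && c <= '9') {
        current = current * 10 + (c - '0')
      } else {
        return None                            // unexpected character: not a date string
      }
      j += 1
    }
    segments(i) = current
    // Same sanity ranges as the patch: four-digit year, 1-12 month, 1-31 day.
    if (segments(0) < 1000 || segments(0) > 9999 ||
        segments(1) < 1 || segments(1) > 12 ||
        segments(2) < 1 || segments(2) > 31) None
    else Some((segments(0), segments(1), segments(2)))
  }

  def main(args: Array[String]): Unit = {
    println(parseDatePart("2015-01-01 12:15:31"))  // Some((2015,1,1)) - time part ignored
    println(parseDatePart("2015-03"))              // Some((2015,3,1)) - day defaults to 1
    println(parseDatePart("2015/03/18"))           // None             - unsupported separator
  }
}
```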
Author: Tarek Auel Author: Tarek Auel Closes #7353 from tarekauel/SPARK-8995 and squashes the following commits: 14f333b [Tarek Auel] [SPARK-8995] added tests for daylight saving time ca1ae69 [Tarek Auel] [SPARK-8995] style fix d20b8b4 [Tarek Auel] [SPARK-8995] bug fix: distinguish between 0 and null ef05753 [Tarek Auel] [SPARK-8995] added check for year >= 1000 01c9ff3 [Tarek Auel] [SPARK-8995] support for time strings 34ec573 [Tarek Auel] fixed style 71622c0 [Tarek Auel] improved timestamp and date parsing 0e30c0a [Tarek Auel] Hive compatibility cfbaed7 [Tarek Auel] fixed wrong checks 71f89c1 [Tarek Auel] [SPARK-8995] minor style fix f7452fa [Tarek Auel] [SPARK-8995] removed old timestamp parsing 30e5aec [Tarek Auel] [SPARK-8995] date and timestamp cast c1083fb [Tarek Auel] [SPARK-8995] cast date strings like '2015-01-01 12:15:31' to date or timestamp --- .../spark/sql/catalyst/expressions/Cast.scala | 17 +- .../sql/catalyst/util/DateTimeUtils.scala | 198 ++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 144 ++++++++++++ .../catalyst/util/DateTimeUtilsSuite.scala | 218 ++++++++++++++++++ 4 files changed, 562 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index ab02addfb4d25..83d5b3b76b0a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -167,17 +167,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // TimestampConverter private[this] def castToTimestamp(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, utfs => { - // Throw away extra if more than 9 decimal places - val s = utfs.toString - val periodIdx = s.indexOf(".") - var n = s - if (periodIdx != -1 && n.length() - periodIdx > 9) { - n = n.substring(0, periodIdx + 10) - } - try DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(n)) - catch { case _: java.lang.IllegalArgumentException => null } - }) + buildCast[UTF8String](_, utfs => DateTimeUtils.stringToTimestamp(utfs).orNull) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0) case LongType => @@ -220,10 +210,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => - try DateTimeUtils.fromJavaDate(Date.valueOf(s.toString)) - catch { case _: java.lang.IllegalArgumentException => null } - ) + buildCast[UTF8String](_, s => DateTimeUtils.stringToDate(s).orNull) case TimestampType => // throw valid precision more than seconds, according to Hive. // Timestamp.nanos is in 0 to 999,999,999, no more than a second. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index c1ddee3ef0230..53c32a0a9802b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -21,6 +21,8 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import java.util.{Calendar, TimeZone} +import org.apache.spark.unsafe.types.UTF8String + /** * Helper functions for converting between internal and external date and time representations. * Dates are exposed externally as java.sql.Date and are represented internally as the number of @@ -180,4 +182,200 @@ object DateTimeUtils { val nanos = (us % MICROS_PER_SECOND) * 1000L (day.toInt, secondsInDay * NANOS_PER_SECOND + nanos) } + + /** + * Parses a given UTF8 date string to the corresponding a corresponding [[Long]] value. + * The return type is [[Option]] in order to distinguish between 0L and null. The following + * formats are allowed: + * + * `yyyy` + * `yyyy-[m]m` + * `yyyy-[m]m-[d]d` + * `yyyy-[m]m-[d]d ` + * `yyyy-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]` + * `yyyy-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]Z` + * `yyyy-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]-[h]h:[m]m` + * `yyyy-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` + * `yyyy-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]` + * `yyyy-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]Z` + * `yyyy-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]-[h]h:[m]m` + * `yyyy-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` + * `[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]` + * `[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]Z` + * `[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]-[h]h:[m]m` + * `[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` + * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]` + * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]Z` + * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]-[h]h:[m]m` + * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` + */ + def stringToTimestamp(s: UTF8String): Option[Long] = { + if (s == null) { + return None + } + var timeZone: Option[Byte] = None + val segments: Array[Int] = Array[Int](1, 1, 1, 0, 0, 0, 0, 0, 0) + var i = 0 + var currentSegmentValue = 0 + val bytes = s.getBytes + var j = 0 + var digitsMilli = 0 + var justTime = false + while (j < bytes.length) { + val b = bytes(j) + val parsedValue = b - '0'.toByte + if (parsedValue < 0 || parsedValue > 9) { + if (j == 0 && b == 'T') { + justTime = true + i += 3 + } else if (i < 2) { + if (b == '-') { + segments(i) = currentSegmentValue + currentSegmentValue = 0 + i += 1 + } else if (i == 0 && b == ':') { + justTime = true + segments(3) = currentSegmentValue + currentSegmentValue = 0 + i = 4 + } else { + return None + } + } else if (i == 2) { + if (b == ' ' || b == 'T') { + segments(i) = currentSegmentValue + currentSegmentValue = 0 + i += 1 + } else { + return None + } + } else if (i == 3 || i == 4) { + if (b == ':') { + segments(i) = currentSegmentValue + currentSegmentValue = 0 + i += 1 + } else { + return None + } + } else if (i == 5 || i == 6) { + if (b == 'Z') { + segments(i) = currentSegmentValue + currentSegmentValue = 0 + i += 1 + timeZone = Some(43) + } else if (b == '-' || b == '+') { + segments(i) = currentSegmentValue + currentSegmentValue = 0 + i += 1 + timeZone = Some(b) + } else if (b == '.' 
&& i == 5) { + segments(i) = currentSegmentValue + currentSegmentValue = 0 + i += 1 + } else { + return None + } + if (i == 6 && b != '.') { + i += 1 + } + } else { + if (b == ':' || b == ' ') { + segments(i) = currentSegmentValue + currentSegmentValue = 0 + i += 1 + } else { + return None + } + } + } else { + if (i == 6) { + digitsMilli += 1 + } + currentSegmentValue = currentSegmentValue * 10 + parsedValue + } + j += 1 + } + + segments(i) = currentSegmentValue + + while (digitsMilli < 6) { + segments(6) *= 10 + digitsMilli += 1 + } + + if (!justTime && (segments(0) < 1000 || segments(0) > 9999 || segments(1) < 1 || + segments(1) > 12 || segments(2) < 1 || segments(2) > 31)) { + return None + } + + if (segments(3) < 0 || segments(3) > 23 || segments(4) < 0 || segments(4) > 59 || + segments(5) < 0 || segments(5) > 59 || segments(6) < 0 || segments(6) > 999999 || + segments(7) < 0 || segments(7) > 23 || segments(8) < 0 || segments(8) > 59) { + return None + } + + val c = if (timeZone.isEmpty) { + Calendar.getInstance() + } else { + Calendar.getInstance( + TimeZone.getTimeZone(f"GMT${timeZone.get.toChar}${segments(7)}%02d:${segments(8)}%02d")) + } + + if (justTime) { + c.set(Calendar.HOUR, segments(3)) + c.set(Calendar.MINUTE, segments(4)) + c.set(Calendar.SECOND, segments(5)) + } else { + c.set(segments(0), segments(1) - 1, segments(2), segments(3), segments(4), segments(5)) + } + + Some(c.getTimeInMillis / 1000 * 1000000 + segments(6)) + } + + /** + * Parses a given UTF8 date string to the corresponding a corresponding [[Int]] value. + * The return type is [[Option]] in order to distinguish between 0 and null. The following + * formats are allowed: + * + * `yyyy`, + * `yyyy-[m]m` + * `yyyy-[m]m-[d]d` + * `yyyy-[m]m-[d]d ` + * `yyyy-[m]m-[d]d *` + * `yyyy-[m]m-[d]dT*` + */ + def stringToDate(s: UTF8String): Option[Int] = { + if (s == null) { + return None + } + val segments: Array[Int] = Array[Int](1, 1, 1) + var i = 0 + var currentSegmentValue = 0 + val bytes = s.getBytes + var j = 0 + while (j < bytes.length && (i < 3 && !(bytes(j) == ' ' || bytes(j) == 'T'))) { + val b = bytes(j) + if (i < 2 && b == '-') { + segments(i) = currentSegmentValue + currentSegmentValue = 0 + i += 1 + } else { + val parsedValue = b - '0'.toByte + if (parsedValue < 0 || parsedValue > 9) { + return None + } else { + currentSegmentValue = currentSegmentValue * 10 + parsedValue + } + } + j += 1 + } + segments(i) = currentSegmentValue + if (segments(0) < 1000 || segments(0) > 9999 || segments(1) < 1 || segments(1) > 12 || + segments(2) < 1 || segments(2) > 31) { + return None + } + val c = Calendar.getInstance() + c.set(segments(0), segments(1) - 1, segments(2), 0, 0, 0) + Some((c.getTimeInMillis / 1000 / 3600 / 24).toInt) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1de161c367a1d..ef8bcd41f7280 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Timestamp, Date} +import java.util.{TimeZone, Calendar} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow @@ -41,6 +42,137 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(v, Literal(expected).dataType), expected) } + 
test("cast string to date") { + var c = Calendar.getInstance() + c.set(2015, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015"), DateType), new Date(c.getTimeInMillis)) + c = Calendar.getInstance() + c.set(2015, 2, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03"), DateType), new Date(c.getTimeInMillis)) + c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03-18"), DateType), new Date(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18 "), DateType), new Date(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18 123142"), DateType), new Date(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18T123123"), DateType), new Date(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18T"), DateType), new Date(c.getTimeInMillis)) + + checkEvaluation(Cast(Literal("2015-03-18X"), DateType), null) + checkEvaluation(Cast(Literal("2015/03/18"), DateType), null) + checkEvaluation(Cast(Literal("2015.03.18"), DateType), null) + checkEvaluation(Cast(Literal("20150318"), DateType), null) + checkEvaluation(Cast(Literal("2015-031-8"), DateType), null) + } + + test("cast string to timestamp") { + checkEvaluation(Cast(Literal("123"), TimestampType), + null) + + var c = Calendar.getInstance() + c.set(2015, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015"), TimestampType), + new Timestamp(c.getTimeInMillis)) + c = Calendar.getInstance() + c.set(2015, 2, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03"), TimestampType), + new Timestamp(c.getTimeInMillis)) + c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03-18"), TimestampType), + new Timestamp(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18 "), TimestampType), + new Timestamp(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18T"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance() + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03-18 12:03:17"), TimestampType), + new Timestamp(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17Z"), TimestampType), + new Timestamp(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18 12:03:17Z"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17-1:0"), TimestampType), + new Timestamp(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17-01:00"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17+07:30"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17+7:3"), TimestampType), + 
new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance() + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + checkEvaluation(Cast(Literal("2015-03-18 12:03:17.123"), TimestampType), + new Timestamp(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 456) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17.456Z"), TimestampType), + new Timestamp(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18 12:03:17.456Z"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123-1:0"), TimestampType), + new Timestamp(c.getTimeInMillis)) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123-01:00"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123+07:30"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123+7:3"), TimestampType), + new Timestamp(c.getTimeInMillis)) + + checkEvaluation(Cast(Literal("2015-03-18 123142"), TimestampType), null) + checkEvaluation(Cast(Literal("2015-03-18T123123"), TimestampType), null) + checkEvaluation(Cast(Literal("2015-03-18X"), TimestampType), null) + checkEvaluation(Cast(Literal("2015/03/18"), TimestampType), null) + checkEvaluation(Cast(Literal("2015.03.18"), TimestampType), null) + checkEvaluation(Cast(Literal("20150318"), TimestampType), null) + checkEvaluation(Cast(Literal("2015-031-8"), TimestampType), null) + checkEvaluation(Cast(Literal("2015-03-18T12:03:17-0:70"), TimestampType), null) + } + test("cast from int") { checkCast(0, false) checkCast(1, true) @@ -149,6 +281,18 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val nts = sts + ".1" val ts = Timestamp.valueOf(nts) + val defaultTimeZone = TimeZone.getDefault + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + var c = Calendar.getInstance() + c.set(2015, 2, 8, 2, 30, 0) + checkEvaluation(cast(cast(new Timestamp(c.getTimeInMillis), StringType), TimestampType), + c.getTimeInMillis * 1000) + c = Calendar.getInstance() + c.set(2015, 10, 1, 2, 30, 0) + checkEvaluation(cast(cast(new Timestamp(c.getTimeInMillis), StringType), TimestampType), + c.getTimeInMillis * 1000) + TimeZone.setDefault(defaultTimeZone) + checkEvaluation(cast("abdef", StringType), "abdef") checkEvaluation(cast("abdef", DecimalType.Unlimited), null) checkEvaluation(cast("abdef", TimestampType), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index f63ac191e7366..c65fcbc4d1bc1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import 
java.text.SimpleDateFormat +import java.util.{TimeZone, Calendar} import org.apache.spark.SparkFunSuite +import org.apache.spark.unsafe.types.UTF8String class DateTimeUtilsSuite extends SparkFunSuite { @@ -86,4 +88,220 @@ class DateTimeUtilsSuite extends SparkFunSuite { checkFromToJavaDate(new Date(df1.parse("1776-07-04 10:30:00").getTime)) checkFromToJavaDate(new Date(df2.parse("1776-07-04 18:30:00 UTC").getTime)) } + + test("string to date") { + val millisPerDay = 1000L * 3600L * 24L + var c = Calendar.getInstance() + c.set(2015, 0, 28, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-01-28")).get === + c.getTimeInMillis / millisPerDay) + c.set(2015, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015")).get === + c.getTimeInMillis / millisPerDay) + c = Calendar.getInstance() + c.set(2015, 2, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03")).get === + c.getTimeInMillis / millisPerDay) + c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18")).get === + c.getTimeInMillis / millisPerDay) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18 ")).get === + c.getTimeInMillis / millisPerDay) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18 123142")).get === + c.getTimeInMillis / millisPerDay) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18T123123")).get === + c.getTimeInMillis / millisPerDay) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18T")).get === + c.getTimeInMillis / millisPerDay) + + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18X")).isEmpty) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015/03/18")).isEmpty) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015.03.18")).isEmpty) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("20150318")).isEmpty) + assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) + } + + test("string to timestamp") { + var c = Calendar.getInstance() + c.set(1969, 11, 31, 16, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("1969-12-31 16:00:00")).get === + c.getTimeInMillis * 1000) + c.set(2015, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015")).get === + c.getTimeInMillis * 1000) + c = Calendar.getInstance() + c.set(2015, 2, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03")).get === + c.getTimeInMillis * 1000) + c = Calendar.getInstance() + c.set(2015, 2, 18, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18")).get === + c.getTimeInMillis * 1000) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 ")).get === + c.getTimeInMillis * 1000) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T")).get === + c.getTimeInMillis * 1000) + + c = Calendar.getInstance() + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17")).get === + c.getTimeInMillis * 1000) + 
assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17")).get === + c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT-13:53")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17-13:53")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17Z")).get === + c.getTimeInMillis * 1000) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17Z")).get === + c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17-1:0")).get === + c.getTimeInMillis * 1000) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17-01:00")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17+07:30")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17+07:03")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance() + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18 12:03:17.123")).get === c.getTimeInMillis * 1000) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.123")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 456) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.456Z")).get === c.getTimeInMillis * 1000) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18 12:03:17.456Z")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.123-1:0")).get === c.getTimeInMillis * 1000) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.123-01:00")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.123+07:30")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.123+07:30")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + assert(DateTimeUtils.stringToTimestamp( + 
UTF8String.fromString("2015-03-18T12:03:17.123121+7:30")).get === + c.getTimeInMillis * 1000 + 121) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.12312+7:30")).get === + c.getTimeInMillis * 1000 + 120) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(Calendar.HOUR, 18) + c.set(Calendar.MINUTE, 12) + c.set(Calendar.SECOND, 15) + c.set(Calendar.MILLISECOND, 123) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("T18:12:15.12312+7:30")).get === + c.getTimeInMillis * 1000 + 120) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(Calendar.HOUR, 18) + c.set(Calendar.MINUTE, 12) + c.set(Calendar.SECOND, 15) + c.set(Calendar.MILLISECOND, 123) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("18:12:15.12312+7:30")).get === + c.getTimeInMillis * 1000 + 120) + + c = Calendar.getInstance() + c.set(2011, 4, 6, 7, 8, 9) + c.set(Calendar.MILLISECOND, 100) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2011-05-06 07:08:09.1000")).get === c.getTimeInMillis * 1000) + + val defaultTimeZone = TimeZone.getDefault + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + + c = Calendar.getInstance() + c.set(2015, 2, 8, 2, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-3-8 2:0:0")).get === c.getTimeInMillis * 1000) + c.add(Calendar.MINUTE, 30) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-3-8 3:30:0")).get === c.getTimeInMillis * 1000) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-3-8 2:30:0")).get === c.getTimeInMillis * 1000) + + c = Calendar.getInstance() + c.set(2015, 10, 1, 1, 59, 0) + c.set(Calendar.MILLISECOND, 0) + c.add(Calendar.MINUTE, 31) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-11-1 2:30:0")).get === c.getTimeInMillis * 1000) + TimeZone.setDefault(defaultTimeZone) + + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("238")).isEmpty) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18X")).isEmpty) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015/03/18")).isEmpty) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015.03.18")).isEmpty) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("20150318")).isEmpty) + assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-031-8")).isEmpty) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03.17-20:0")).isEmpty) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03.17-0:70")).isEmpty) + assert(DateTimeUtils.stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03.17-1:0:0")).isEmpty) + } } From b536d5dc6c2c712270b8130ddd9945dff19a27d9 Mon Sep 17 00:00:00 2001 From: Jan Prach Date: Thu, 16 Jul 2015 18:42:41 +0100 Subject: [PATCH 43/46] [SPARK-9015] [BUILD] Clean project import in scala ide Cleanup maven for a clean import in scala-ide / eclipse. 
* remove groovy plugin which is really not needed at all * add-source from build-helper-maven-plugin is not needed as recent version of scala-maven-plugin do it automatically * add lifecycle-mapping plugin to hide a few useless warnings from ide Author: Jan Prach Closes #7375 from jendap/clean-project-import-in-scala-ide and squashes the following commits: c4b4c0f [Jan Prach] fix whitespaces 5a83e07 [Jan Prach] Revert "remove java compiler warnings from java tests" 312007e [Jan Prach] scala-maven-plugin itself add scala sources by default f47d856 [Jan Prach] remove spark-1.4-staging repository c8a54db [Jan Prach] remove java compiler warnings from java tests 999a068 [Jan Prach] remove some maven warnings in scala ide 80fbdc5 [Jan Prach] remove groovy and gmavenplus plugin --- pom.xml | 130 +++++++++++++++++++---------------------------- repl/pom.xml | 2 - sql/core/pom.xml | 1 - sql/hive/pom.xml | 1 - tools/pom.xml | 4 -- 5 files changed, 53 insertions(+), 85 deletions(-) diff --git a/pom.xml b/pom.xml index aa49e2ab7294b..c5c655834bdeb 100644 --- a/pom.xml +++ b/pom.xml @@ -152,7 +152,6 @@ 1.2.1 4.3.2 3.4.1 - ${project.build.directory}/spark-test-classpath.txt 2.10.4 2.10 ${scala.version} @@ -283,18 +282,6 @@ false - - - spark-1.4-staging - Spark 1.4 RC4 Staging Repository - https://repository.apache.org/content/repositories/orgapachespark-1112 - - true - - - false - - @@ -318,17 +305,6 @@ unused 1.0.0 - - - org.codehaus.groovy - groovy-all - 2.3.7 - provided - + + + org.eclipse.m2e + lifecycle-mapping + 1.0.0 + + + + + + org.apache.maven.plugins + maven-dependency-plugin + [2.8,) + + build-classpath + + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + [2.6,) + + test-jar + + + + + + + + + org.apache.maven.plugins + maven-antrun-plugin + [1.8,) + + run + + + + + + + + + + @@ -1429,34 +1457,12 @@ test - ${test_classpath_file} + test_classpath - - - org.codehaus.gmavenplus - gmavenplus-plugin - 1.5 - - - process-test-classes - - execute - - - - - - - - -