diff --git a/bin/pyspark b/bin/pyspark index 6655725ef8e8e..96f30a260a09e 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -50,22 +50,47 @@ fi . "$FWDIR"/bin/load-spark-env.sh -# Figure out which Python executable to use +# In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` +# executable, while the worker would still be launched using PYSPARK_PYTHON. +# +# In Spark 1.2, we removed the documentation of the IPYTHON and IPYTHON_OPTS variables and added +# PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS to allow IPython to be used for the driver. +# Now, users can simply set PYSPARK_DRIVER_PYTHON=ipython to use IPython and set +# PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver +# (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython +# and executor Python executables. +# +# For backwards-compatibility, we retain the old IPYTHON and IPYTHON_OPTS variables. + +# Determine the Python executable to use if PYSPARK_PYTHON or PYSPARK_DRIVER_PYTHON isn't set: +if hash python2.7 2>/dev/null; then + # Attempt to use Python 2.7, if installed: + DEFAULT_PYTHON="python2.7" +else + DEFAULT_PYTHON="python" +fi + +# Determine the Python executable to use for the driver: +if [[ -n "$IPYTHON_OPTS" || "$IPYTHON" == "1" ]]; then + # If IPython options are specified, assume user wants to run IPython + # (for backwards-compatibility) + PYSPARK_DRIVER_PYTHON_OPTS="$PYSPARK_DRIVER_PYTHON_OPTS $IPYTHON_OPTS" + PYSPARK_DRIVER_PYTHON="ipython" +elif [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then + PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"$DEFAULT_PYTHON"}" +fi + +# Determine the Python executable to use for the executors: if [[ -z "$PYSPARK_PYTHON" ]]; then - if [[ "$IPYTHON" = "1" || -n "$IPYTHON_OPTS" ]]; then - # for backward compatibility - PYSPARK_PYTHON="ipython" + if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && $DEFAULT_PYTHON != "python2.7" ]]; then + echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2 + exit 1 else - PYSPARK_PYTHON="python" + PYSPARK_PYTHON="$DEFAULT_PYTHON" fi fi export PYSPARK_PYTHON -if [[ -z "$PYSPARK_PYTHON_OPTS" && -n "$IPYTHON_OPTS" ]]; then - # for backward compatibility - PYSPARK_PYTHON_OPTS="$IPYTHON_OPTS" -fi - # Add the PySpark classes to the Python path: export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH" export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" @@ -93,9 +118,9 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR if [[ -n "$PYSPARK_DOC_TEST" ]]; then - exec "$PYSPARK_PYTHON" -m doctest $1 + exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1 else - exec "$PYSPARK_PYTHON" $1 + exec "$PYSPARK_DRIVER_PYTHON" $1 fi exit fi @@ -111,5 +136,5 @@ if [[ "$1" =~ \.py$ ]]; then else # PySpark shell requires special handling downstream export PYSPARK_SHELL=1 - exec "$PYSPARK_PYTHON" $PYSPARK_PYTHON_OPTS + exec "$PYSPARK_DRIVER_PYTHON" $PYSPARK_DRIVER_PYTHON_OPTS fi diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 8ca731038e528..e72826dc25f41 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -26,6 +26,8 @@ import scala.collection.JavaConversions._ import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} import com.google.common.io.Files +import org.apache.spark.util.Utils + /** * Utilities for tests. 
Included in main codebase since it's used by multiple * projects. @@ -42,8 +44,7 @@ private[spark] object TestUtils { * in order to avoid interference between tests. */ def createJarWithClasses(classNames: Seq[String], value: String = ""): URL = { - val tempDir = Files.createTempDir() - tempDir.deleteOnExit() + val tempDir = Utils.createTempDir() val files = for (name <- classNames) yield createCompiledClass(name, tempDir, value) val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) createJar(files, jarFile) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 71bdf0fe1b917..e314408c067e9 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -108,10 +108,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) // Create and start the worker - val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.worker")) + val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.worker")) val workerEnv = pb.environment() workerEnv.putAll(envVars) workerEnv.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: + workerEnv.put("PYTHONUNBUFFERED", "YES") val worker = pb.start() // Redirect worker stdout and stderr @@ -149,10 +151,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.daemon")) + val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.daemon")) val workerEnv = pb.environment() workerEnv.putAll(envVars) workerEnv.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: + workerEnv.put("PYTHONUNBUFFERED", "YES") daemon = pb.start() val in = new DataInputStream(daemon.getInputStream) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 065ddda50e65e..f2687ce6b42b4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -130,7 +130,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") System.exit(-1) - case AssociationErrorEvent(cause, _, remoteAddress, _) => + case AssociationErrorEvent(cause, _, remoteAddress, _, _) => println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") println(s"Cause was: $cause") System.exit(-1) diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 79b4d7ea41a33..af94b05ce3847 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -34,7 +34,8 @@ object PythonRunner { val pythonFile = args(0) val pyFiles = args(1) val otherArgs = args.slice(2, args.length) - val pythonExec = sys.env.get("PYSPARK_PYTHON").getOrElse("python") // TODO: get this from conf + val pythonExec = + sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", 
sys.env.getOrElse("PYSPARK_PYTHON", "python")) // Format python file paths before adding them to the PYTHONPATH val formattedPythonFile = formatPath(pythonFile) @@ -57,6 +58,7 @@ object PythonRunner { val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs) val env = builder.environment() env.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 32790053a6be8..98a93d1fcb2a3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -154,7 +154,7 @@ private[spark] class AppClient( logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() - case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) => + case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => logWarning(s"Could not connect to $address: $cause") case StopAppClient => diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 6d0d0bbe5ecec..63a8ac817b618 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -54,7 +54,7 @@ private[spark] class WorkerWatcher(workerUrl: String) case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound) + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) if isWorker(remoteAddress) => // These logs may not be seen if the worker (and associated pipe) has died logError(s"Could not initialize connection to worker $workerUrl. Exiting.") diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index a4409181ec907..4c9ca97a2a6b7 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -66,13 +66,27 @@ sealed abstract class ManagedBuffer { final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long) extends ManagedBuffer { + /** + * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889). + * Avoid unless there's a good reason not to. + */ + private val MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024; + override def size: Long = length override def nioByteBuffer(): ByteBuffer = { var channel: FileChannel = null try { channel = new RandomAccessFile(file, "r").getChannel - channel.map(MapMode.READ_ONLY, offset, length) + // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead. 
+ if (length < MIN_MEMORY_MAP_BYTES) { + val buf = ByteBuffer.allocate(length.toInt) + channel.read(buf, offset) + buf.flip() + buf + } else { + channel.map(MapMode.READ_ONLY, offset, length) + } } catch { case e: IOException => Try(channel.size).toOption match { diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index f368209980f93..4f6f5e235811d 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -20,11 +20,14 @@ package org.apache.spark.network.nio import java.net._ import java.nio._ import java.nio.channels._ +import java.util.concurrent.ConcurrentLinkedQueue import java.util.LinkedList import org.apache.spark._ +import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.util.control.NonFatal private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, @@ -51,7 +54,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, @volatile private var closed = false var onCloseCallback: Connection => Unit = null - var onExceptionCallback: (Connection, Exception) => Unit = null + val onExceptionCallbacks = new ConcurrentLinkedQueue[(Connection, Throwable) => Unit] var onKeyInterestChangeCallback: (Connection, Int) => Unit = null val remoteAddress = getRemoteAddress() @@ -130,20 +133,24 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, onCloseCallback = callback } - def onException(callback: (Connection, Exception) => Unit) { - onExceptionCallback = callback + def onException(callback: (Connection, Throwable) => Unit) { + onExceptionCallbacks.add(callback) } def onKeyInterestChange(callback: (Connection, Int) => Unit) { onKeyInterestChangeCallback = callback } - def callOnExceptionCallback(e: Exception) { - if (onExceptionCallback != null) { - onExceptionCallback(this, e) - } else { - logError("Error in connection to " + getRemoteConnectionManagerId() + - " and OnExceptionCallback not registered", e) + def callOnExceptionCallbacks(e: Throwable) { + onExceptionCallbacks foreach { + callback => + try { + callback(this, e) + } catch { + case NonFatal(e) => { + logWarning("Ignored error in onExceptionCallback", e) + } + } } } @@ -323,7 +330,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logError("Error connecting to " + address, e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) } } } @@ -348,7 +355,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logWarning("Error finishing connection to " + address, e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) } } true @@ -393,7 +400,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() return false } @@ -420,7 +427,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, case e: Exception => logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() } @@ -577,7 +584,7 @@ private[spark] class ReceivingConnection( } catch { case e: Exception 
=> { logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() return false } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 01cd27a907eea..6b00190c5eccc 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -34,6 +34,8 @@ import scala.language.postfixOps import org.apache.spark._ import org.apache.spark.util.Utils +import scala.util.Try +import scala.util.control.NonFatal private[nio] class ConnectionManager( port: Int, @@ -51,14 +53,23 @@ private[nio] class ConnectionManager( class MessageStatus( val message: Message, val connectionManagerId: ConnectionManagerId, - completionHandler: MessageStatus => Unit) { + completionHandler: Try[Message] => Unit) { - /** This is non-None if message has been ack'd */ - var ackMessage: Option[Message] = None + def success(ackMessage: Message) { + if (ackMessage == null) { + failure(new NullPointerException) + } + else { + completionHandler(scala.util.Success(ackMessage)) + } + } - def markDone(ackMessage: Option[Message]) { - this.ackMessage = ackMessage - completionHandler(this) + def failWithoutAck() { + completionHandler(scala.util.Failure(new IOException("Failed without being ACK'd"))) + } + + def failure(e: Throwable) { + completionHandler(scala.util.Failure(e)) } } @@ -72,14 +83,32 @@ private[nio] class ConnectionManager( conf.getInt("spark.core.connection.handler.threads.max", 60), conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-message-executor")) + Utils.namedThreadFactory("handle-message-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleMessageExecutor is not handled properly", t) + } + } + + } private val handleReadWriteExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.io.threads.min", 4), conf.getInt("spark.core.connection.io.threads.max", 32), conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-read-write-executor")) + Utils.namedThreadFactory("handle-read-write-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleReadWriteExecutor is not handled properly", t) + } + } + + } // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : // which should be executed asap @@ -153,17 +182,24 @@ private[nio] class ConnectionManager( } handleReadWriteExecutor.execute(new Runnable { override def run() { - var register: Boolean = false try { - register = conn.write() - } finally { - writeRunnableStarted.synchronized { - writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() - if (needReregister && conn.changeInterestForWrite()) { - conn.registerInterest() + var register: Boolean = false + try { + register = conn.write() + } finally { + writeRunnableStarted.synchronized { + writeRunnableStarted -= key + val needReregister = register || conn.resetForceReregister() + if (needReregister && 
conn.changeInterestForWrite()) { + conn.registerInterest() + } } } + } catch { + case NonFatal(e) => { + logError("Error when writing to " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } } } ) @@ -187,16 +223,23 @@ private[nio] class ConnectionManager( } handleReadWriteExecutor.execute(new Runnable { override def run() { - var register: Boolean = false try { - register = conn.read() - } finally { - readRunnableStarted.synchronized { - readRunnableStarted -= key - if (register && conn.changeInterestForRead()) { - conn.registerInterest() + var register: Boolean = false + try { + register = conn.read() + } finally { + readRunnableStarted.synchronized { + readRunnableStarted -= key + if (register && conn.changeInterestForRead()) { + conn.registerInterest() + } } } + } catch { + case NonFatal(e) => { + logError("Error when reading from " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } } } ) @@ -213,19 +256,25 @@ private[nio] class ConnectionManager( handleConnectExecutor.execute(new Runnable { override def run() { + try { + var tries: Int = 10 + while (tries >= 0) { + if (conn.finishConnect(false)) return + // Sleep ? + Thread.sleep(1) + tries -= 1 + } - var tries: Int = 10 - while (tries >= 0) { - if (conn.finishConnect(false)) return - // Sleep ? - Thread.sleep(1) - tries -= 1 + // fallback to previous behavior : we should not really come here since this method was + // triggered since channel became connectable : but at times, the first finishConnect need + // not succeed : hence the loop to retry a few 'times'. + conn.finishConnect(true) + } catch { + case NonFatal(e) => { + logError("Error when finishConnect for " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } - - // fallback to previous behavior : we should not really come here since this method was - // triggered since channel became connectable : but at times, the first finishConnect need - // not succeed : hence the loop to retry a few 'times'. 
- conn.finishConnect(true) } } ) } @@ -246,16 +295,16 @@ private[nio] class ConnectionManager( handleConnectExecutor.execute(new Runnable { override def run() { try { - conn.callOnExceptionCallback(e) + conn.callOnExceptionCallbacks(e) } catch { // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) + case NonFatal(e) => logDebug("Ignoring exception", e) } try { conn.close() } catch { // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) + case NonFatal(e) => logDebug("Ignoring exception", e) } } }) @@ -448,7 +497,7 @@ private[nio] class ConnectionManager( messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) .foreach(status => { logInfo("Notifying " + status) - status.markDone(None) + status.failWithoutAck() }) messageStatuses.retain((i, status) => { @@ -477,7 +526,7 @@ private[nio] class ConnectionManager( for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { logInfo("Notifying " + s) - s.markDone(None) + s.failWithoutAck() } messageStatuses.retain((i, status) => { @@ -492,7 +541,7 @@ private[nio] class ConnectionManager( } } - def handleConnectionError(connection: Connection, e: Exception) { + def handleConnectionError(connection: Connection, e: Throwable) { logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId()) removeConnection(connection) @@ -510,9 +559,17 @@ private[nio] class ConnectionManager( val runnable = new Runnable() { val creationTime = System.currentTimeMillis def run() { - logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message, connection) - logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") + try { + logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") + handleMessage(connectionManagerId, message, connection) + logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") + } catch { + case NonFatal(e) => { + logError("Error when handling messages from " + + connection.getRemoteConnectionManagerId(), e) + connection.callOnExceptionCallbacks(e) + } + } } } handleMessageExecutor.execute(runnable) @@ -651,7 +708,7 @@ private[nio] class ConnectionManager( messageStatuses.get(bufferMessage.ackId) match { case Some(status) => { messageStatuses -= bufferMessage.ackId - status.markDone(Some(message)) + status.success(message) } case None => { /** @@ -770,6 +827,12 @@ private[nio] class ConnectionManager( val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, newConnectionId, securityManager) + newConnection.onException { + case (conn, e) => { + logError("Exception while sending message.", e) + reportSendingMessageFailure(message.id, e) + } + } logTrace("creating new sending connection: " + newConnectionId) registerRequests.enqueue(newConnection) @@ -782,13 +845,36 @@ private[nio] class ConnectionManager( "connectionid: " + connection.connectionId) if (authEnabled) { - checkSendAuthFirst(connectionManagerId, connection) + try { + checkSendAuthFirst(connectionManagerId, connection) + } catch { + case NonFatal(e) => { + reportSendingMessageFailure(message.id, e) + } + } } logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") connection.send(message) wakeupSelector() } + private def reportSendingMessageFailure(messageId: Int, e: 
Throwable): Unit = { + // need to tell sender it failed + messageStatuses.synchronized { + val s = messageStatuses.get(messageId) + s match { + case Some(msgStatus) => { + messageStatuses -= messageId + logInfo("Notifying " + msgStatus.connectionManagerId) + msgStatus.failure(e) + } + case None => { + logError("no messageStatus for failed message id: " + messageId) + } + } + } + } + private def wakeupSelector() { selector.wakeup() } @@ -807,9 +893,11 @@ private[nio] class ConnectionManager( override def run(): Unit = { messageStatuses.synchronized { messageStatuses.remove(message.id).foreach ( s => { - promise.failure( - new IOException("sendMessageReliably failed because ack " + - s"was not received within $ackTimeout sec")) + val e = new IOException("sendMessageReliably failed because ack " + + s"was not received within $ackTimeout sec") + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } }) } } @@ -817,15 +905,23 @@ private[nio] class ConnectionManager( val status = new MessageStatus(message, connectionManagerId, s => { timeoutTask.cancel() - s.ackMessage match { - case None => // Indicates a failure where we either never sent or never got ACK'd - promise.failure(new IOException("sendMessageReliably failed without being ACK'd")) - case Some(ackMessage) => + s match { + case scala.util.Failure(e) => + // Indicates a failure where we either never sent or never got ACK'd + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } + case scala.util.Success(ackMessage) => if (ackMessage.hasError) { - promise.failure( - new IOException("sendMessageReliably failed with ACK that signalled a remote error")) + val e = new IOException( + "sendMessageReliably failed with ACK that signalled a remote error") + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } } else { - promise.success(ackMessage) + if (!promise.trySuccess(ackMessage)) { + logWarning("Drop ackMessage because promise is completed") + } } } }) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index f0006b42aee4f..32e6b15bb0999 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -21,6 +21,7 @@ import java.text.SimpleDateFormat import java.util.{Locale, Date} import scala.xml.Node + import org.apache.spark.Logging /** Utility functions for generating XML pages with spark content. */ @@ -169,6 +170,7 @@ private[spark] object UIUtils extends Logging { refreshInterval: Option[Int] = None): Seq[Node] = { val appName = activeTab.appName + val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." val header = activeTab.headerTabs.map { tab =>
@@ -187,7 +189,9 @@ private[spark] object UIUtils extends Logging {
[The Scala XML for the header tab links was lost in extraction; the surviving fragment renders {tab.name} for each tab.]
@@ -216,8 +220,10 @@ private[spark] object UIUtils extends Logging {
[The Scala XML for the navigation bar and page heading was lost in extraction; the surviving fragment renders {title}, next to the newly defined shortAppName.]
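The hunk above introduces shortAppName, which keeps the navigation bar usable for long application names: anything under 36 characters is shown as-is, while longer names are cut to 32 characters plus an ellipsis. A minimal, self-contained sketch of that truncation rule follows; the shortenAppName helper and the sample names are illustrative only, not part of the patch:

```scala
object AppNameTruncation {
  // Same rule as the shortAppName value added to UIUtils:
  // short names pass through, long names become 32 characters plus "...".
  def shortenAppName(appName: String): String =
    if (appName.length < 36) appName else appName.take(32) + "..."

  def main(args: Array[String]): Unit = {
    println(shortenAppName("PageRank"))   // printed unchanged
    println(shortenAppName("A" * 50))     // 35 characters: 32 "A"s plus "..."
  }
}
```

Every truncated name comes back at exactly 35 characters, so a navbar entry has a fixed upper bound on width.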
    diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index db01be596e073..2414e4c65237e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -103,7 +103,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val taskHeaders: Seq[String] = Seq( - "Index", "ID", "Attempt", "Status", "Locality Level", "Executor", + "Index", "ID", "Attempt", "Status", "Locality Level", "Executor ID / Host", "Launch Time", "Duration", "GC Time", "Accumulators") ++ {if (hasInput) Seq("Input") else Nil} ++ {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ @@ -282,7 +282,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { } {info.status} {info.taskLocality} - {info.host} + {info.executorId} / {info.host} {UIUtils.formatDate(new Date(info.launchTime))} {formatDuration} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 3d307b3c16d3e..07477dd460a4b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -168,6 +168,20 @@ private[spark] object Utils extends Logging { private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() + // Add a shutdown hook to delete the temp dirs when the JVM exits + Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dirs") { + override def run(): Unit = Utils.logUncaughtExceptions { + logDebug("Shutdown hook called") + shutdownDeletePaths.foreach { dirPath => + try { + Utils.deleteRecursively(new File(dirPath)) + } catch { + case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) + } + } + } + }) + // Register the path to be deleted via shutdown hook def registerShutdownDeleteDir(file: File) { val absolutePath = file.getAbsolutePath() @@ -252,14 +266,6 @@ private[spark] object Utils extends Logging { } registerShutdownDeleteDir(dir) - - // Add a shutdown hook to delete the temp dir when the JVM exits - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) { - override def run() { - // Attempt to delete if some patch which is parent of this is not already registered. - if (! 
hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir) - } - }) dir } @@ -666,15 +672,30 @@ private[spark] object Utils extends Logging { */ def deleteRecursively(file: File) { if (file != null) { - if (file.isDirectory() && !isSymlink(file)) { - for (child <- listFilesSafely(file)) { - deleteRecursively(child) + try { + if (file.isDirectory && !isSymlink(file)) { + var savedIOException: IOException = null + for (child <- listFilesSafely(file)) { + try { + deleteRecursively(child) + } catch { + // In case of multiple exceptions, only last one will be thrown + case ioe: IOException => savedIOException = ioe + } + } + if (savedIOException != null) { + throw savedIOException + } + shutdownDeletePaths.synchronized { + shutdownDeletePaths.remove(file.getAbsolutePath) + } } - } - if (!file.delete()) { - // Delete can also fail if the file simply did not exist - if (file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath) + } finally { + if (!file.delete()) { + // Delete can also fail if the file simply did not exist + if (file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath) + } } } } @@ -713,7 +734,7 @@ private[spark] object Utils extends Logging { */ def doesDirectoryContainAnyNewFiles(dir: File, cutoff: Long): Boolean = { if (!dir.isDirectory) { - throw new IllegalArgumentException("$dir is not a directory!") + throw new IllegalArgumentException(s"$dir is not a directory!") } val filesAndDirs = dir.listFiles() val cutoffTimeInMillis = System.currentTimeMillis - (cutoff * 1000) diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 7e18f45de7b5b..a8867020e457d 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark import java.io._ import java.util.jar.{JarEntry, JarOutputStream} -import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.spark.SparkContext._ @@ -41,8 +40,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { override def beforeAll() { super.beforeAll() - tmpDir = Files.createTempDir() - tmpDir.deleteOnExit() + tmpDir = Utils.createTempDir() val testTempDir = new File(tmpDir, "test") testTempDir.mkdir() diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 4a53d25012ad9..a2b74c4419d46 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -21,7 +21,6 @@ import java.io.{File, FileWriter} import scala.io.Source -import com.google.common.io.Files import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, TextInputFormat, TextOutputFormat} @@ -39,8 +38,7 @@ class FileSuite extends FunSuite with LocalSparkContext { override def beforeEach() { super.beforeEach() - tempDir = Files.createTempDir() - tempDir.deleteOnExit() + tempDir = Utils.createTempDir() } override def afterEach() { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 1fef79ad1001f..cbc0bd178d894 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -146,7 +146,7 @@ class 
MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~123B, and no exception should be thrown @@ -164,7 +164,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 4cba90e8f2afe..1cdf50d5c08c7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.util.Utils import org.scalatest.FunSuite import org.scalatest.Matchers -import com.google.common.io.Files class SparkSubmitSuite extends FunSuite with Matchers { def beforeAll() { @@ -332,7 +331,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { } def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { - val tmpDir = Files.createTempDir() + val tmpDir = Utils.createTempDir() val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf") val writer = new OutputStreamWriter(new FileOutputStream(defaultsConf)) diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index d5ebfb3f3fae1..12d1c7b2faba6 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -23,8 +23,6 @@ import java.io.FileOutputStream import scala.collection.immutable.IndexedSeq -import com.google.common.io.Files - import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite @@ -66,9 +64,7 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { * 3) Does the contents be the same. 
*/ test("Correctness of WholeTextFileRecordReader.") { - - val dir = Files.createTempDir() - dir.deleteOnExit() + val dir = Utils.createTempDir() println(s"Local disk address is ${dir.toString}.") WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 75b01191901b8..3620e251cc139 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -24,13 +24,14 @@ import org.apache.hadoop.util.Progressable import scala.collection.mutable.{ArrayBuffer, HashSet} import scala.util.Random -import com.google.common.io.Files import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext} import org.apache.spark.{Partitioner, SharedSparkContext} import org.apache.spark.SparkContext._ +import org.apache.spark.util.Utils + import org.scalatest.FunSuite class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { @@ -381,14 +382,16 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } test("zero-partition RDD") { - val emptyDir = Files.createTempDir() - emptyDir.deleteOnExit() - val file = sc.textFile(emptyDir.getAbsolutePath) - assert(file.partitions.size == 0) - assert(file.collect().toList === Nil) - // Test that a shuffle on the file works, because this used to be a bug - assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) - emptyDir.delete() + val emptyDir = Utils.createTempDir() + try { + val file = sc.textFile(emptyDir.getAbsolutePath) + assert(file.partitions.isEmpty) + assert(file.collect().toList === Nil) + // Test that a shuffle on the file works, because this used to be a bug + assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) + } finally { + Utils.deleteRecursively(emptyDir) + } } test("keys and values") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 3efa85431876b..abc300fcffaf9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.scheduler import scala.collection.mutable import scala.io.Source -import com.google.common.io.Files import org.apache.hadoop.fs.{FileStatus, Path} import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -51,8 +50,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { private var logDirPath: Path = _ before { - testDir = Files.createTempDir() - testDir.deleteOnExit() + testDir = Utils.createTempDir() logDirPath = Utils.getFilePath(testDir, "spark-events") } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 48114feee6233..e05f373392d4a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -19,7 +19,6 @@ package 
org.apache.spark.scheduler import java.io.{File, PrintWriter} -import com.google.common.io.Files import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -39,8 +38,7 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { private var testDir: File = _ before { - testDir = Files.createTempDir() - testDir.deleteOnExit() + testDir = Utils.createTempDir() } after { diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index e4522e00a622d..bc5c74c126b74 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -19,22 +19,13 @@ package org.apache.spark.storage import java.io.{File, FileWriter} -import org.apache.spark.network.nio.NioBlockTransferService -import org.apache.spark.shuffle.hash.HashShuffleManager - -import scala.collection.mutable import scala.language.reflectiveCalls -import akka.actor.Props -import com.google.common.io.Files import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} import org.apache.spark.SparkConf -import org.apache.spark.scheduler.LiveListenerBus -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.util.{AkkaUtils, Utils} -import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.util.Utils class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { private val testConf = new SparkConf(false) @@ -48,10 +39,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before override def beforeAll() { super.beforeAll() - rootDir0 = Files.createTempDir() - rootDir0.deleteOnExit() - rootDir1 = Files.createTempDir() - rootDir1.deleteOnExit() + rootDir0 = Utils.createTempDir() + rootDir1 = Utils.createTempDir() rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath } diff --git a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala index c3dd156b40514..dc2a05631d83d 100644 --- a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala @@ -21,7 +21,6 @@ import java.io.{File, IOException} import scala.io.Source -import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfter, FunSuite} @@ -44,7 +43,7 @@ class FileLoggerSuite extends FunSuite with BeforeAndAfter { private var logDirPathString: String = _ before { - testDir = Files.createTempDir() + testDir = Utils.createTempDir() logDirPath = Utils.getFilePath(testDir, "test-file-logger") logDirPathString = logDirPath.toString } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index e63d9d085e385..0344da60dae66 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -112,7 +112,7 @@ class UtilsSuite extends FunSuite { } test("reading offset bytes of a file") { - val tmpDir2 = Files.createTempDir() + val tmpDir2 = Utils.createTempDir() tmpDir2.deleteOnExit() val f1Path = tmpDir2 + "/f1" val f1 = new FileOutputStream(f1Path) @@ -141,7 +141,7 @@ class UtilsSuite extends FunSuite { } test("reading offset bytes across multiple files") { - val tmpDir = 
Files.createTempDir() + val tmpDir = Utils.createTempDir() tmpDir.deleteOnExit() val files = (1 to 3).map(i => new File(tmpDir, i.toString)) Files.write("0123456789", files(0), Charsets.UTF_8) @@ -308,4 +308,28 @@ class UtilsSuite extends FunSuite { } } + test("deleteRecursively") { + val tempDir1 = Utils.createTempDir() + assert(tempDir1.exists()) + Utils.deleteRecursively(tempDir1) + assert(!tempDir1.exists()) + + val tempDir2 = Utils.createTempDir() + val tempFile1 = new File(tempDir2, "foo.txt") + Files.touch(tempFile1) + assert(tempFile1.exists()) + Utils.deleteRecursively(tempFile1) + assert(!tempFile1.exists()) + + val tempDir3 = new File(tempDir2, "subdir") + assert(tempDir3.mkdir()) + val tempFile2 = new File(tempDir3, "bar.txt") + Files.touch(tempFile2) + assert(tempFile2.exists()) + Utils.deleteRecursively(tempDir2) + assert(!tempDir2.exists()) + assert(!tempDir3.exists()) + assert(!tempFile2.exists()) + } + } diff --git a/dev/run-tests b/dev/run-tests index 4be2baaf48cd1..f47fcf66ff7e7 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -42,7 +42,7 @@ function handle_error () { elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=2.0.0-mr1-cdh4.1.1" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Dhadoop.version=2.2.0" + export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.3" ]; then export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" fi diff --git a/dev/scalastyle b/dev/scalastyle index efb5f291ea3b7..c3b356bcb3c06 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -26,6 +26,8 @@ echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 yarn/scalasty >> scalastyle.txt ERRORS=$(cat scalastyle.txt | grep -e "\") +rm scalastyle.txt + if test ! -z "$ERRORS"; then echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS" exit 1 diff --git a/docs/README.md b/docs/README.md index 79708c3df9106..0facecdd5f767 100644 --- a/docs/README.md +++ b/docs/README.md @@ -54,19 +54,19 @@ phase, use the following sytax: // supported languages too. {% endhighlight %} -## API Docs (Scaladoc and Epydoc) +## API Docs (Scaladoc and Sphinx) You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory. -Similarly, you can build just the PySpark epydoc by running `epydoc --config epydoc.conf` from the -SPARK_PROJECT_ROOT/pyspark directory. Documentation is only generated for classes that are listed as +Similarly, you can build just the PySpark docs by running `make html` from the +SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as public in `__init__.py`. When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the -PySpark docs using [epydoc](http://epydoc.sourceforge.net/). +PySpark docs [Sphinx](http://sphinx-doc.org/). NOTE: To skip the step of building and copying over the Scala and Python API docs, run `SKIP_API=1 jekyll`. 
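Several of the test-suite hunks above follow one pattern: drop Files.createTempDir() and its per-directory deleteOnExit() or shutdown hook, call Utils.createTempDir() instead (which registers the directory with the single shutdown hook now installed in Utils), and clean up eagerly with Utils.deleteRecursively in a finally block where the test can. A rough, JDK-only sketch of that pattern is below; createTempDir and deleteRecursively here are simplified stand-ins for illustration, not Spark's implementations:

```scala
import java.io.File
import java.nio.file.Files

object TempDirPattern {
  // Directories to clean up if the JVM exits before the test deletes them.
  private val toDeleteOnShutdown = collection.mutable.HashSet.empty[File]

  // One JVM-wide hook registered once, instead of one hook per directory.
  sys.addShutdownHook {
    toDeleteOnShutdown.synchronized { toDeleteOnShutdown.foreach(deleteRecursively) }
  }

  def createTempDir(): File = {
    val dir = Files.createTempDirectory("spark-test").toFile
    toDeleteOnShutdown.synchronized { toDeleteOnShutdown += dir }
    dir
  }

  def deleteRecursively(file: File): Unit = {
    if (file.isDirectory) {
      Option(file.listFiles()).getOrElse(Array.empty[File]).foreach(deleteRecursively)
    }
    file.delete()
    toDeleteOnShutdown.synchronized { toDeleteOnShutdown -= file }
  }

  def main(args: Array[String]): Unit = {
    val tempDir = createTempDir()
    try {
      // ... exercise code that writes into tempDir ...
      println(s"working in ${tempDir.getAbsolutePath}")
    } finally {
      deleteRecursively(tempDir)  // eager cleanup; the shutdown hook is only a safety net
    }
  }
}
```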
diff --git a/docs/_config.yml b/docs/_config.yml index 7bc3a78e2d265..f4bf242ac191b 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -8,6 +8,9 @@ gems: kramdown: entity_output: numeric +include: + - _static + # These allow the documentation to be updated with nerw releases # of Spark, Scala, and Mesos. SPARK_VERSION: 1.0.0-SNAPSHOT diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 3b02e090aec28..4566a2fff562b 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -63,19 +63,20 @@ puts "cp -r " + source + "/. " + dest cp_r(source + "/.", dest) - # Build Epydoc for Python - puts "Moving to python directory and building epydoc." - cd("../python") - puts `epydoc --config epydoc.conf` + # Build Sphinx docs for Python - puts "Moving back into docs dir." - cd("../docs") + puts "Moving to python/docs directory and building sphinx." + cd("../python/docs") + puts `make html` + + puts "Moving back into home dir." + cd("../../") puts "Making directory api/python" - mkdir_p "api/python" + mkdir_p "docs/api/python" - puts "cp -r ../python/docs/. api/python" - cp_r("../python/docs/.", "api/python") + puts "cp -r python/docs/_build/html/. docs/api/python" + cp_r("python/docs/_build/html/.", "docs/api/python") cd("..") end diff --git a/docs/configuration.md b/docs/configuration.md index 1c33855365170..f311f0d2a6206 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -103,6 +103,14 @@ of the most common options to set are: (e.g. 512m, 2g). + + spark.driver.memory + 512m + + Amount of memory to use for the driver process, i.e. where SparkContext is initialized. + (e.g. 512m, 2g). + + spark.serializer org.apache.spark.serializer.
    JavaSerializer diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 8e8cc1dd983f8..18420afb27e3c 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -211,17 +211,17 @@ For a complete list of options, run `pyspark --help`. Behind the scenes, It is also possible to launch the PySpark shell in [IPython](http://ipython.org), the enhanced Python interpreter. PySpark works with IPython 1.0.0 and later. To -use IPython, set the `PYSPARK_PYTHON` variable to `ipython` when running `bin/pyspark`: +use IPython, set the `PYSPARK_DRIVER_PYTHON` variable to `ipython` when running `bin/pyspark`: {% highlight bash %} -$ PYSPARK_PYTHON=ipython ./bin/pyspark +$ PYSPARK_DRIVER_PYTHON=ipython ./bin/pyspark {% endhighlight %} -You can customize the `ipython` command by setting `PYSPARK_PYTHON_OPTS`. For example, to launch +You can customize the `ipython` command by setting `PYSPARK_DRIVER_PYTHON_OPTS`. For example, to launch the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support: {% highlight bash %} -$ PYSPARK_PYTHON=ipython PYSPARK_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark +$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark {% endhighlight %}
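The guide text above reflects the same lookup order that bin/pyspark and PythonRunner now implement: PYSPARK_DRIVER_PYTHON first, then PYSPARK_PYTHON, then plain "python". A small sketch of that resolution chain over an environment map, in the spirit of the PythonRunner change; the surrounding object, the driverPythonExec name, and the sample maps are illustrative assumptions:

```scala
object DriverPythonResolution {
  /** Resolve the Python executable for the driver, preferring the driver-specific override. */
  def driverPythonExec(env: Map[String, String] = sys.env): String =
    env.getOrElse("PYSPARK_DRIVER_PYTHON", env.getOrElse("PYSPARK_PYTHON", "python"))

  def main(args: Array[String]): Unit = {
    // With neither variable set, fall back to plain "python".
    println(driverPythonExec(Map.empty))
    // Setting only PYSPARK_PYTHON still controls the driver (and the executors).
    println(driverPythonExec(Map("PYSPARK_PYTHON" -> "python2.7")))
    // PYSPARK_DRIVER_PYTHON overrides the driver side only, e.g. to run IPython.
    println(driverPythonExec(Map(
      "PYSPARK_DRIVER_PYTHON" -> "ipython",
      "PYSPARK_PYTHON" -> "python2.7")))
  }
}
```

Executors keep resolving their interpreter from PYSPARK_PYTHON alone, which is why the updated launcher script errors out when IPython is requested for the driver but no python2.7 is available for the workers.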
    diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 941dfb988b9fb..0d6b82b4944f3 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -32,6 +32,7 @@ import tempfile import time import urllib2 +import warnings from optparse import OptionParser from sys import stderr import boto @@ -61,8 +62,8 @@ def parse_args(): "-s", "--slaves", type="int", default=1, help="Number of slaves to launch (default: %default)") parser.add_option( - "-w", "--wait", type="int", default=120, - help="Seconds to wait for nodes to start (default: %default)") + "-w", "--wait", type="int", + help="DEPRECATED (no longer necessary) - Seconds to wait for nodes to start") parser.add_option( "-k", "--key-pair", help="Key pair to use on instances") @@ -195,18 +196,6 @@ def get_or_make_group(conn, name): return conn.create_security_group(name, "Spark EC2 group") -# Wait for a set of launched instances to exit the "pending" state -# (i.e. either to start running or to fail and be terminated) -def wait_for_instances(conn, instances): - while True: - for i in instances: - i.update() - if len([i for i in instances if i.state == 'pending']) > 0: - time.sleep(5) - else: - return - - # Check whether a given EC2 instance object is in a state we consider active, # i.e. not terminating or terminated. We count both stopping and stopped as # active since we can restart stopped clusters. @@ -594,7 +583,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten - ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v3") + ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v4") print "Deploying files to master..." deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules) @@ -619,14 +608,64 @@ def setup_spark_cluster(master, opts): print "Ganglia started at http://%s:5080/ganglia" % master -# Wait for a whole cluster (masters, slaves and ZooKeeper) to start up -def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes): - print "Waiting for instances to start up..." - time.sleep(5) - wait_for_instances(conn, master_nodes) - wait_for_instances(conn, slave_nodes) - print "Waiting %d more seconds..." % wait_secs - time.sleep(wait_secs) +def is_ssh_available(host, opts): + "Checks if SSH is available on the host." + try: + with open(os.devnull, 'w') as devnull: + ret = subprocess.check_call( + ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', + '%s@%s' % (opts.user, host), stringify_command('true')], + stdout=devnull, + stderr=devnull + ) + return ret == 0 + except subprocess.CalledProcessError as e: + return False + + +def is_cluster_ssh_available(cluster_instances, opts): + for i in cluster_instances: + if not is_ssh_available(host=i.ip_address, opts=opts): + return False + else: + return True + + +def wait_for_cluster_state(cluster_instances, cluster_state, opts): + """ + cluster_instances: a list of boto.ec2.instance.Instance + cluster_state: a string representing the desired state of all the instances in the cluster + value can be 'ssh-ready' or a valid value from boto.ec2.instance.InstanceState such as + 'running', 'terminated', etc. 
+ (would be nice to replace this with a proper enum: http://stackoverflow.com/a/1695250) + """ + sys.stdout.write( + "Waiting for all instances in cluster to enter '{s}' state.".format(s=cluster_state) + ) + sys.stdout.flush() + + num_attempts = 0 + + while True: + time.sleep(3 * num_attempts) + + for i in cluster_instances: + s = i.update() # capture output to suppress print to screen in newer versions of boto + + if cluster_state == 'ssh-ready': + if all(i.state == 'running' for i in cluster_instances) and \ + is_cluster_ssh_available(cluster_instances, opts): + break + else: + if all(i.state == cluster_state for i in cluster_instances): + break + + num_attempts += 1 + + sys.stdout.write(".") + sys.stdout.flush() + + sys.stdout.write("\n") # Get number of local disks available for a given EC2 instance type. @@ -868,6 +907,16 @@ def real_main(): (opts, action, cluster_name) = parse_args() # Input parameter validation + if opts.wait is not None: + # NOTE: DeprecationWarnings are silent in 2.7+ by default. + # To show them, run Python with the -Wdefault switch. + # See: https://docs.python.org/3.5/whatsnew/2.7.html + warnings.warn( + "This option is deprecated and has no effect. " + "spark-ec2 automatically waits as long as necessary for clusters to startup.", + DeprecationWarning + ) + if opts.ebs_vol_num > 8: print >> stderr, "ebs-vol-num cannot be greater than 8" sys.exit(1) @@ -890,7 +939,11 @@ def real_main(): (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) else: (master_nodes, slave_nodes) = launch_cluster(conn, opts, cluster_name) - wait_for_cluster(conn, opts.wait, master_nodes, slave_nodes) + wait_for_cluster_state( + cluster_instances=(master_nodes + slave_nodes), + cluster_state='ssh-ready', + opts=opts + ) setup_cluster(conn, master_nodes, slave_nodes, opts, True) elif action == "destroy": @@ -919,7 +972,11 @@ def real_main(): else: group_names = [opts.security_group_prefix + "-master", opts.security_group_prefix + "-slaves"] - + wait_for_cluster_state( + cluster_instances=(master_nodes + slave_nodes), + cluster_state='terminated', + opts=opts + ) attempt = 1 while attempt <= 3: print "Attempt %d" % attempt @@ -1019,7 +1076,11 @@ def real_main(): for inst in master_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.start() - wait_for_cluster(conn, opts.wait, master_nodes, slave_nodes) + wait_for_cluster_state( + cluster_instances=(master_nodes + slave_nodes), + cluster_state='ssh-ready', + opts=opts + ) setup_cluster(conn, master_nodes, slave_nodes, opts, False) else: diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala new file mode 100644 index 0000000000000..ae6057758d6fc --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scala.reflect.runtime.universe._ + +/** + * Abstract class for parameter case classes. + * This overrides the [[toString]] method to print all case class fields by name and value. + * @tparam T Concrete parameter class. + */ +abstract class AbstractParams[T: TypeTag] { + + private def tag: TypeTag[T] = typeTag[T] + + /** + * Finds all case class fields in concrete class instance, and outputs them in JSON-style format: + * { + * [field name]:\t[field value]\n + * [field name]:\t[field value]\n + * ... + * } + */ + override def toString: String = { + val tpe = tag.tpe + val allAccessors = tpe.declarations.collect { + case m: MethodSymbol if m.isCaseAccessor => m + } + val mirror = runtimeMirror(getClass.getClassLoader) + val instanceMirror = mirror.reflect(this) + allAccessors.map { f => + val paramName = f.name.toString + val fieldMirror = instanceMirror.reflectField(f) + val paramValue = fieldMirror.get + s" $paramName:\t$paramValue" + }.mkString("{\n", ",\n", "\n}") + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index a6f78d2441db1..1edd2432a0352 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -55,7 +55,7 @@ object BinaryClassification { stepSize: Double = 1.0, algorithm: Algorithm = LR, regType: RegType = L2, - regParam: Double = 0.1) + regParam: Double = 0.1) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala index d6b2fe430e5a4..e49129c4e7844 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -35,6 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext} object Correlations { case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala new file mode 100644 index 0000000000000..cb1abbd18fd4d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix} +import org.apache.spark.{SparkConf, SparkContext} + +/** + * Compute the similar columns of a matrix, using cosine similarity. + * + * The input matrix must be stored in row-oriented dense format, one line per row with its entries + * separated by space. For example, + * {{{ + * 0.5 1.0 + * 2.0 3.0 + * 4.0 5.0 + * }}} + * represents a 3-by-2 matrix, whose first row is (0.5, 1.0). + * + * Example invocation: + * + * bin/run-example mllib.CosineSimilarity \ + * --threshold 0.1 data/mllib/sample_svm_data.txt + */ +object CosineSimilarity { + case class Params(inputFile: String = null, threshold: Double = 0.1) + extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("CosineSimilarity") { + head("CosineSimilarity: an example app.") + opt[Double]("threshold") + .required() + .text(s"threshold similarity: to tradeoff computation vs quality estimate") + .action((x, c) => c.copy(threshold = x)) + arg[String]("") + .required() + .text(s"input file, one row per line, space-separated") + .action((x, c) => c.copy(inputFile = x)) + note( + """ + |For example, the following command runs this app on a dataset: + | + | ./bin/spark-submit --class org.apache.spark.examples.mllib.CosineSimilarity \ + | examplesjar.jar \ + | --threshold 0.1 data/mllib/sample_svm_data.txt + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + System.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName("CosineSimilarity") + val sc = new SparkContext(conf) + + // Load and parse the data file. + val rows = sc.textFile(params.inputFile).map { line => + val values = line.split(' ').map(_.toDouble) + Vectors.dense(values) + }.cache() + val mat = new RowMatrix(rows) + + // Compute similar columns perfectly, with brute force. 
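
The code that follows compares these brute-force column similarities with the DIMSUM estimate returned by columnSimilarities(threshold) and reports their mean absolute difference. As a reference for what the exact quantity is, here is a small NumPy sketch (independent of Spark) on the 3-by-2 matrix from the scaladoc above:

import numpy as np

# Rows of the example matrix; the columns are the items being compared.
mat = np.array([[0.5, 1.0],
                [2.0, 3.0],
                [4.0, 5.0]])

# Exact all-pairs column cosine similarity, i.e. what columnSimilarities()
# computes by brute force: dot(col_i, col_j) / (||col_i|| * ||col_j||).
norms = np.sqrt((mat ** 2).sum(axis=0))
sims = mat.T.dot(mat) / np.outer(norms, norms)

print(sims[0, 1])  # similarity of column 0 and column 1, ~0.9954
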
+ val exact = mat.columnSimilarities() + + // Compute similar columns with estimation using DIMSUM + val approx = mat.columnSimilarities(params.threshold) + + val exactEntries = exact.entries.map { case MatrixEntry(i, j, u) => ((i, j), u) } + val approxEntries = approx.entries.map { case MatrixEntry(i, j, v) => ((i, j), v) } + val MAE = exactEntries.leftOuterJoin(approxEntries).values.map { + case (u, Some(v)) => + math.abs(u - v) + case (u, None) => + math.abs(u) + }.mean() + + println(s"Average absolute error in estimate is: $MAE") + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 4adc91d2fbe65..837d0591478c5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -62,7 +62,7 @@ object DecisionTreeRunner { minInfoGain: Double = 0.0, numTrees: Int = 1, featureSubsetStrategy: String = "auto", - fracTest: Double = 0.2) + fracTest: Double = 0.2) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() @@ -138,9 +138,11 @@ object DecisionTreeRunner { def run(params: Params) { - val conf = new SparkConf().setAppName("DecisionTreeRunner") + val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params") val sc = new SparkContext(conf) + println(s"DecisionTreeRunner with parameters:\n$params") + // Load training data and cache it. val origExamples = params.dataFormat match { case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache() @@ -235,7 +237,10 @@ object DecisionTreeRunner { minInstancesPerNode = params.minInstancesPerNode, minInfoGain = params.minInfoGain) if (params.numTrees == 1) { + val startTime = System.nanoTime() val model = DecisionTree.train(training, strategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") if (model.numNodes < 20) { println(model.toDebugString) // Print full model. } else { @@ -259,8 +264,11 @@ object DecisionTreeRunner { } else { val randomSeed = Utils.random.nextInt() if (params.algo == Classification) { + val startTime = System.nanoTime() val model = RandomForest.trainClassifier(training, strategy, params.numTrees, params.featureSubsetStrategy, randomSeed) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") if (model.totalNumNodes < 30) { println(model.toDebugString) // Print full model. } else { @@ -275,8 +283,11 @@ object DecisionTreeRunner { println(s"Test accuracy = $testAccuracy") } if (params.algo == Regression) { + val startTime = System.nanoTime() val model = RandomForest.trainRegressor(training, strategy, params.numTrees, params.featureSubsetStrategy, randomSeed) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") if (model.totalNumNodes < 30) { println(model.toDebugString) // Print full model. 
} else { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala index 89dfa26c2299c..11e35598baf50 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala @@ -44,7 +44,7 @@ object DenseKMeans { input: String = null, k: Int = -1, numIterations: Int = 10, - initializationMode: InitializationMode = Parallel) + initializationMode: InitializationMode = Parallel) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 05b7d66f8dffd..e1f9622350135 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -47,7 +47,7 @@ object LinearRegression extends App { numIterations: Int = 100, stepSize: Double = 1.0, regType: RegType = L2, - regParam: Double = 0.1) + regParam: Double = 0.1) extends AbstractParams[Params] val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 98aaedb9d7dc9..fc6678013b932 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -55,7 +55,7 @@ object MovieLensALS { rank: Int = 10, numUserBlocks: Int = -1, numProductBlocks: Int = -1, - implicitPrefs: Boolean = false) + implicitPrefs: Boolean = false) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala index 4532512c01f84..6e4e2d07f284b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -36,6 +36,7 @@ import org.apache.spark.{SparkConf, SparkContext} object MultivariateSummarizer { case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala index f01b8266e3fe3..663c12734af68 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -33,6 +33,7 @@ import org.apache.spark.SparkContext._ object SampledRDDs { case class Params(input: String = "data/mllib/sample_binary_classification_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index 952fa2a5109a4..f1ff4e6911f5e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -37,7 +37,7 @@ object SparseNaiveBayes { input: String = null, minPartitions: Int = 0, numFeatures: Int = -1, - lambda: Double = 1.0) + lambda: Double = 1.0) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/mllib/pom.xml b/mllib/pom.xml index a5eeef88e9d62..696e9396f627c 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -57,7 +57,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.9 + 0.10 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index e9f41758581e3..f7251e65e04f1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -29,6 +29,8 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ +import org.apache.spark.mllib.feature.Word2Vec +import org.apache.spark.mllib.feature.Word2VecModel import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.random.{RandomRDDs => RG} @@ -42,9 +44,9 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils - /** * :: DeveloperApi :: * The Java stubs necessary for the Python mllib bindings. @@ -287,6 +289,59 @@ class PythonMLLibAPI extends Serializable { ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha) } + /** + * Java stub for Python mllib Word2Vec fit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. 
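
The trainWord2Vec stub whose doc comment appears above persists the incoming JavaRDD with MEMORY_AND_DISK_SER before fitting and unpersists it afterwards (see the body below), since Word2Vec makes several passes over the corpus. The same caching discipline on the PySpark side, sketched with a toy in-memory corpus rather than real data:

from pyspark import SparkContext, StorageLevel

sc = SparkContext("local[2]", "PersistBeforeIterating")

# Toy corpus: an RDD of tokenized sentences (stand-in for real training data).
corpus = sc.parallelize(["a b c", "a c b"] * 100).map(lambda line: line.split(" "))

corpus.persist(StorageLevel.MEMORY_AND_DISK_SER)
try:
    # Stand-in for a multi-pass training job over the cached blocks.
    vocab_size = corpus.flatMap(lambda words: words).distinct().count()
    print(vocab_size)
finally:
    corpus.unpersist()
    sc.stop()
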
+ * @param dataJRDD input JavaRDD + * @param vectorSize size of vector + * @param learningRate initial learning rate + * @param numPartitions number of partitions + * @param numIterations number of iterations + * @param seed initial seed for random generator + * @return A handle to java Word2VecModelWrapper instance at python side + */ + def trainWord2Vec( + dataJRDD: JavaRDD[java.util.ArrayList[String]], + vectorSize: Int, + learningRate: Double, + numPartitions: Int, + numIterations: Int, + seed: Long): Word2VecModelWrapper = { + val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER) + val word2vec = new Word2Vec() + .setVectorSize(vectorSize) + .setLearningRate(learningRate) + .setNumPartitions(numPartitions) + .setNumIterations(numIterations) + .setSeed(seed) + val model = word2vec.fit(data) + data.unpersist() + new Word2VecModelWrapper(model) + } + + private[python] class Word2VecModelWrapper(model: Word2VecModel) { + def transform(word: String): Vector = { + model.transform(word) + } + + def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = { + val vec = transform(word) + findSynonyms(vec, num) + } + + def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = { + val result = model.findSynonyms(vector, num) + val similarity = Vectors.dense(result.map(_._2)) + val words = result.map(_._1) + val ret = new java.util.LinkedList[java.lang.Object]() + ret.add(words) + ret.add(similarity) + ret + } + } + /** * Java stub for Python mllib DecisionTree.train(). * This stub returns a handle to the Java object instead of the content of the Java object. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index 3afb47767281c..4734251127bb4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.feature -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => brzNorm} import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -47,7 +47,7 @@ class Normalizer(p: Double) extends VectorTransformer { * @return normalized vector. If the norm of the input is zero, it will return the input vector. */ override def transform(vector: Vector): Vector = { - var norm = vector.toBreeze.norm(p) + var norm = brzNorm(vector.toBreeze, p) if (norm != 0.0) { // For dense vector, we've to allocate new memory for new output vector. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index fc1444705364a..d321994c2a651 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -67,7 +67,7 @@ private case class VocabWord( class Word2Vec extends Serializable with Logging { private var vectorSize = 100 - private var startingAlpha = 0.025 + private var learningRate = 0.025 private var numPartitions = 1 private var numIterations = 1 private var seed = Utils.random.nextLong() @@ -84,7 +84,7 @@ class Word2Vec extends Serializable with Logging { * Sets initial learning rate (default: 0.025). 
*/ def setLearningRate(learningRate: Double): this.type = { - this.startingAlpha = learningRate + this.learningRate = learningRate this } @@ -286,7 +286,7 @@ class Word2Vec extends Serializable with Logging { val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) - var alpha = startingAlpha + var alpha = learningRate for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) @@ -300,8 +300,8 @@ class Word2Vec extends Serializable with Logging { lwc = wordCount // TODO: discount by iteration? alpha = - startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) - if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001 + learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) + if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001 logInfo("wordCount = " + wordCount + ", alpha = " + alpha) } wc += sentence.size @@ -437,7 +437,7 @@ class Word2VecModel private[mllib] ( * Find synonyms of a word * @param word a word * @param num number of synonyms to find - * @return array of (word, similarity) + * @return array of (word, cosineSimilarity) */ def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b311d10023894..03eeaa707715b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -532,6 +532,14 @@ object DecisionTree extends Serializable with Logging { Some(mutableNodeToFeatures.toMap) } + // array of nodes to train indexed by node index in group + val nodes = new Array[Node](numNodes) + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node + } + } + // Calculate best splits for all nodes in the group timer.start("chooseSplits") @@ -568,7 +576,7 @@ object DecisionTree extends Serializable with Logging { // find best split for each node val (split: Split, stats: InformationGainStats, predict: Predict) = - binsToBestSplit(aggStats, splits, featuresForNode) + binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) (nodeIndex, (split, stats, predict)) }.collectAsMap() @@ -587,17 +595,30 @@ object DecisionTree extends Serializable with Logging { // Extract info for this node. Create children if not leaf. 
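
Returning to the Word2Vec.scala hunk above: renaming startingAlpha to learningRate leaves the decay schedule itself unchanged. The rate shrinks linearly with the fraction of the corpus processed and is floored at 1/10000 of the initial value. In isolation the schedule looks like this (a sketch with simplified names, not the Spark code):

def word2vec_alpha(learning_rate, num_partitions, word_count, train_words_count):
    alpha = learning_rate * (1 - num_partitions * float(word_count) / (train_words_count + 1))
    return max(alpha, learning_rate * 0.0001)  # floor, as in the hunk above


# With one partition the rate decays linearly from 0.025 towards the floor.
for progress in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(word2vec_alpha(0.025, 1, int(progress * 100000), 100000))
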
val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth) assert(node.id == nodeIndex) - node.predict = predict.predict + node.predict = predict node.isLeaf = isLeaf node.stats = Some(stats) + node.impurity = stats.impurity logDebug("Node = " + node) if (!isLeaf) { node.split = Some(split) - node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex))) - node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex))) - nodeQueue.enqueue((treeIndex, node.leftNode.get)) - nodeQueue.enqueue((treeIndex, node.rightNode.get)) + val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth + val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) + val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) + node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex), + stats.leftPredict, stats.leftImpurity, leftChildIsLeaf)) + node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex), + stats.rightPredict, stats.rightImpurity, rightChildIsLeaf)) + + // enqueue left child and right child if they are not leaves + if (!leftChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.leftNode.get)) + } + if (!rightChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.rightNode.get)) + } + logDebug("leftChildIndex = " + node.leftNode.get.id + ", impurity = " + stats.leftImpurity) logDebug("rightChildIndex = " + node.rightNode.get.id + @@ -617,7 +638,8 @@ object DecisionTree extends Serializable with Logging { private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata): InformationGainStats = { + metadata: DecisionTreeMetadata, + impurity: Double): InformationGainStats = { val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count @@ -630,11 +652,6 @@ object DecisionTree extends Serializable with Logging { val totalCount = leftCount + rightCount - val parentNodeAgg = leftImpurityCalculator.copy - parentNodeAgg.add(rightImpurityCalculator) - - val impurity = parentNodeAgg.calculate() - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() @@ -649,7 +666,18 @@ object DecisionTree extends Serializable with Logging { return InformationGainStats.invalidInformationGainStats } - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity) + // calculate left and right predict + val leftPredict = calculatePredict(leftImpurityCalculator) + val rightPredict = calculatePredict(rightImpurityCalculator) + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, + leftPredict, rightPredict) + } + + private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = { + val predict = impurityCalculator.predict + val prob = impurityCalculator.prob(predict) + new Predict(predict, prob) } /** @@ -657,17 +685,17 @@ object DecisionTree extends Serializable with Logging { * Note that this function is called only once for each node. 
* @param leftImpurityCalculator left node aggregates for a split * @param rightImpurityCalculator right node aggregates for a split - * @return predict value for current node + * @return predict value and impurity for current node */ - private def calculatePredict( + private def calculatePredictImpurity( leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): Predict = { + rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { val parentNodeAgg = leftImpurityCalculator.copy parentNodeAgg.add(rightImpurityCalculator) - val predict = parentNodeAgg.predict - val prob = parentNodeAgg.prob(predict) + val predict = calculatePredict(parentNodeAgg) + val impurity = parentNodeAgg.calculate() - new Predict(predict, prob) + (predict, impurity) } /** @@ -678,10 +706,16 @@ object DecisionTree extends Serializable with Logging { private def binsToBestSplit( binAggregates: DTStatsAggregator, splits: Array[Array[Split]], - featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = { + featuresForNode: Option[Array[Int]], + node: Node): (Split, InformationGainStats, Predict) = { - // calculate predict only once - var predict: Option[Predict] = None + // calculate predict and impurity if current node is top node + val level = Node.indexToLevel(node.id) + var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) { + None + } else { + Some((node.predict, node.impurity)) + } // For each (feature, split), calculate the gain, and select the best (feature, split). val (bestSplit, bestSplitStats) = @@ -708,9 +742,10 @@ object DecisionTree extends Serializable with Logging { val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIdx, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -722,9 +757,10 @@ object DecisionTree extends Serializable with Logging { Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIndex, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -794,9 +830,10 @@ object DecisionTree extends Serializable with Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) 
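
The refactor above threads the parent node's impurity into calculateGainForSplit (computed once per node by calculatePredictImpurity) instead of re-merging the child aggregates for every candidate split. The gain itself is the usual weighted impurity decrease; a small self-contained sketch with Gini impurity, for reference only:

def gini(counts):
    total = float(sum(counts))
    if total == 0.0:
        return 0.0
    return 1.0 - sum((c / total) ** 2 for c in counts)


def information_gain(parent_impurity, left_counts, right_counts):
    left, right = float(sum(left_counts)), float(sum(right_counts))
    total = left + right
    weighted_children = (left / total) * gini(left_counts) + (right / total) * gini(right_counts)
    return parent_impurity - weighted_children


left, right = [8, 2], [1, 9]                   # per-class counts on each side of a split
parent = gini([l + r for l, r in zip(left, right)])
print(information_gain(parent, left, right))   # 0.245: the split reduces impurity
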
val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIndex, gainStats) }.maxBy(_._2.gain) val categoriesForSplit = @@ -807,9 +844,7 @@ object DecisionTree extends Serializable with Logging { } }.maxBy(_._2.gain) - assert(predict.isDefined, "must calculate predict for each node") - - (bestSplit, bestSplitStats, predict.get) + (bestSplit, bestSplitStats, predictWithImpurity.get._1) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index a89e71e115806..9a50ecb550c38 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -26,13 +26,17 @@ import org.apache.spark.annotation.DeveloperApi * @param impurity current node impurity * @param leftImpurity left node impurity * @param rightImpurity right node impurity + * @param leftPredict left node predict + * @param rightPredict right node predict */ @DeveloperApi class InformationGainStats( val gain: Double, val impurity: Double, val leftImpurity: Double, - val rightImpurity: Double) extends Serializable { + val rightImpurity: Double, + val leftPredict: Predict, + val rightPredict: Predict) extends Serializable { override def toString = { "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" @@ -58,5 +62,6 @@ private[tree] object InformationGainStats { * denote that current split doesn't satisfies minimum info gain or * minimum number of instances per node. */ - val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0) + val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, + new Predict(0.0, 0.0), new Predict(0.0, 0.0)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 56c3e25d9285f..2179da8dbe03e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -32,7 +32,8 @@ import org.apache.spark.mllib.linalg.Vector * * @param id integer node id, from 1 * @param predict predicted value at the node - * @param isLeaf whether the leaf is a node + * @param impurity current node impurity + * @param isLeaf whether the node is a leaf * @param split split to calculate left and right nodes * @param leftNode left child * @param rightNode right child @@ -41,7 +42,8 @@ import org.apache.spark.mllib.linalg.Vector @DeveloperApi class Node ( val id: Int, - var predict: Double, + var predict: Predict, + var impurity: Double, var isLeaf: Boolean, var split: Option[Split], var leftNode: Option[Node], @@ -49,7 +51,7 @@ class Node ( var stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + - "split = " + split + ", stats = " + stats + "impurity = " + impurity + "split = " + split + ", stats = " + stats /** * build the left node and right nodes if not leaf @@ -62,6 +64,7 @@ class Node ( logDebug("id = " + id + ", split = " + split) logDebug("stats = " + stats) logDebug("predict = " + predict) + logDebug("impurity = " + impurity) if (!isLeaf) { leftNode = Some(nodes(Node.leftChildIndex(id))) 
rightNode = Some(nodes(Node.rightChildIndex(id))) @@ -77,7 +80,7 @@ class Node ( */ def predict(features: Vector) : Double = { if (isLeaf) { - predict + predict.predict } else{ if (split.get.featureType == Continuous) { if (features(split.get.feature) <= split.get.threshold) { @@ -109,7 +112,7 @@ class Node ( } else { Some(rightNode.get.deepCopy()) } - new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats) + new Node(id, predict, impurity, isLeaf, split, leftNodeCopy, rightNodeCopy, stats) } /** @@ -154,7 +157,7 @@ class Node ( } val prefix: String = " " * indentFactor if (isLeaf) { - prefix + s"Predict: $predict\n" + prefix + s"Predict: ${predict.predict}\n" } else { prefix + s"If ${splitToString(split.get, left=true)}\n" + leftNode.get.subtreeToString(indentFactor + 1) + @@ -170,7 +173,27 @@ private[tree] object Node { /** * Return a node with the given node id (but nothing else set). */ - def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None) + def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue), -1.0, + false, None, None, None, None) + + /** + * Construct a node with nodeIndex, predict, impurity and isLeaf parameters. + * This is used in `DecisionTree.findBestSplits` to construct child nodes + * after finding the best splits for parent nodes. + * Other fields are set at next level. + * @param nodeIndex integer node id, from 1 + * @param predict predicted value at the node + * @param impurity current node impurity + * @param isLeaf whether the node is a leaf + * @return new node instance + */ + def apply( + nodeIndex: Int, + predict: Predict, + impurity: Double, + isLeaf: Boolean): Node = { + new Node(nodeIndex, predict, impurity, isLeaf, None, None, None, None) + } /** * Return the index of the left child of this node. 
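
With these changes Node.predict is no longer a bare Double but a Predict(predict, prob) pair, and leaves return predict.predict when classifying. A toy Python rendering of that structure (illustrative only, continuous splits on a single feature):

from collections import namedtuple

Predict = namedtuple("Predict", ["predict", "prob"])


class Node(object):
    def __init__(self, predict, impurity, is_leaf,
                 feature=None, threshold=None, left=None, right=None):
        self.predict = predict          # a Predict pair, not a bare value
        self.impurity = impurity
        self.is_leaf = is_leaf
        self.feature, self.threshold = feature, threshold
        self.left, self.right = left, right

    def predict_features(self, features):
        if self.is_leaf:
            return self.predict.predict     # mirrors `predict.predict` above
        if features[self.feature] <= self.threshold:
            return self.left.predict_features(features)
        return self.right.predict_features(features)


left_leaf = Node(Predict(0.0, 1.0), 0.0, True)
right_leaf = Node(Predict(1.0, 1.0), 0.0, True)
root = Node(Predict(0.5, 0.5), 0.5, False, feature=0, threshold=1.5,
            left=left_leaf, right=right_leaf)
print(root.predict_features([2.0]))  # 1.0
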
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala index fb76dccfdf79e..2bf9d9816ae45 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.feature import org.scalatest.FunSuite +import breeze.linalg.{norm => brzNorm} + import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -50,10 +52,10 @@ class NormalizerSuite extends FunSuite with LocalSparkContext { assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(data1(0).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(2).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(3).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(4).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(0).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(2).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(3).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(4).toBreeze, 1) ~== 1.0 absTol 1E-5) assert(data1(0) ~== Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))) absTol 1E-5) assert(data1(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) @@ -77,10 +79,10 @@ class NormalizerSuite extends FunSuite with LocalSparkContext { assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(data2(0).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(2).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(3).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(4).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(0).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(2).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(3).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(4).toBreeze, 2) ~== 1.0 absTol 1E-5) assert(data2(0) ~== Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))) absTol 1E-5) assert(data2(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index a48ed71a1c5fc..98a72b0c4d750 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -253,7 +253,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = rootNode.stats.get assert(stats.gain > 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) assert(stats.impurity > 0.2) } @@ -282,7 +282,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = rootNode.stats.get assert(stats.gain > 0) - assert(rootNode.predict === 0.6) + assert(rootNode.predict.predict === 0.6) assert(stats.impurity > 0.2) } @@ -352,7 +352,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) } test("Binary classification stump with fixed label 0 for Entropy") { @@ -377,7 +377,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) 
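
In NormalizerSuite above, the switch from vector.toBreeze.norm(p) to brzNorm(vector.toBreeze, p) tracks the Breeze 0.10 bump earlier in this diff; the property being asserted is unchanged, namely that each non-zero vector has unit p-norm after Normalizer(p).transform. The same check in NumPy terms (arbitrary example vector):

import numpy as np

v = np.array([0.6, -4.1, 2.2])
for p in (1, 2):
    normalized = v / np.sum(np.abs(v) ** p) ** (1.0 / p)
    print(p, np.sum(np.abs(normalized) ** p) ** (1.0 / p))  # ~1.0 for each p
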
assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 0) + assert(rootNode.predict.predict === 0) } test("Binary classification stump with fixed label 1 for Entropy") { @@ -402,7 +402,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) } test("Second level node building with vs. without groups") { @@ -471,7 +471,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats1.impurity === stats2.impurity) assert(stats1.leftImpurity === stats2.leftImpurity) assert(stats1.rightImpurity === stats2.rightImpurity) - assert(children1(i).predict === children2(i).predict) + assert(children1(i).predict.predict === children2(i).predict.predict) } } @@ -646,7 +646,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = DecisionTree.train(rdd, strategy) assert(model.topNode.isLeaf) - assert(model.topNode.predict == 0.0) + assert(model.topNode.predict.predict == 0.0) val predicts = rdd.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) @@ -693,7 +693,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = DecisionTree.train(input, strategy) assert(model.topNode.isLeaf) - assert(model.topNode.predict == 0.0) + assert(model.topNode.predict.predict == 0.0) val predicts = input.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) @@ -705,6 +705,92 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val gain = rootNode.stats.get assert(gain == InformationGainStats.invalidInformationGainStats) } + + test("Avoid aggregation on the last level") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, + numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue leaf nodes into node queue + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity 
=== 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } + + test("Avoid aggregation if impurity is 0.0") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue a node into node queue if its impurity is 0.0 + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } } object DecisionTreeSuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 8ef2bb1bf6a78..0dbe766b4d917 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -67,8 +67,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { |0 |0 2:4.0 4:5.0 6:6.0 """.stripMargin - val tempDir = Files.createTempDir() - tempDir.deleteOnExit() + val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") Files.write(lines, file, Charsets.US_ASCII) val path = tempDir.toURI.toString @@ -100,7 +99,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))), LabeledPoint(0.0, Vectors.dense(1.01, 2.02, 3.03)) ), 2) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() val outputDir = new File(tempDir, "output") MLUtils.saveAsLibSVMFile(examples, outputDir.toURI.toString) val lines = outputDir.listFiles() @@ -166,7 +165,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { Vectors.sparse(2, Array(1), Array(-1.0)), Vectors.dense(0.0, 1.0) ), 2) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() val outputDir = new File(tempDir, "vectors") val path = outputDir.toURI.toString vectors.saveAsTextFile(path) @@ -181,7 +180,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0))), LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) ), 2) - val 
tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() val outputDir = new File(tempDir, "points") val path = outputDir.toURI.toString points.saveAsTextFile(path) diff --git a/pom.xml b/pom.xml index 7756c89b00cad..d047b9e307d4b 100644 --- a/pom.xml +++ b/pom.xml @@ -118,7 +118,7 @@ 0.18.1 shaded-protobuf org.spark-project.akka - 2.2.3-shaded-protobuf + 2.3.4-spark 1.7.5 1.2.17 1.0.4 @@ -127,7 +127,7 @@ 0.94.6 1.4.0 3.4.5 - 0.12.0 + 0.12.0-protobuf 1.4.3 1.2.3 8.1.14.v20131031 diff --git a/python/docs/conf.py b/python/docs/conf.py index c368cf81a003b..8e6324f058251 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -55,9 +55,9 @@ # built documents. # # The short X.Y version. -version = '1.1' +version = '1.2-SNAPSHOT' # The full version, including alpha/beta/rc tags. -release = '' +release = '1.2-SNAPSHOT' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -102,7 +102,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'default' +html_theme = 'nature' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -121,7 +121,7 @@ # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +html_logo = "../../docs/img/spark-logo-hd.png" # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 @@ -154,10 +154,10 @@ #html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +html_domain_indices = False # If false, no index is generated. -#html_use_index = True +html_use_index = False # If true, the index is split into individual pages for each letter. #html_split_index = False diff --git a/python/docs/index.rst b/python/docs/index.rst index e0f4e5c192acf..703bef644de28 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -3,7 +3,7 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -Welcome to PySpark API reference! +Welcome to Spark Python API Docs! =================================== Contents: @@ -25,14 +25,12 @@ Core classes: Main entry point for Spark functionality. :class:`pyspark.RDD` - + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Indices and tables ================== -* :ref:`genindex` -* :ref:`modindex` * :ref:`search` diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst index e95d19e97f151..4548b8739ed91 100644 --- a/python/docs/pyspark.mllib.rst +++ b/python/docs/pyspark.mllib.rst @@ -20,6 +20,14 @@ pyspark.mllib.clustering module :undoc-members: :show-inheritance: +pyspark.mllib.feature module +------------------------------- + +.. automodule:: pyspark.mllib.feature + :members: + :undoc-members: + :show-inheritance: + pyspark.mllib.linalg module --------------------------- diff --git a/python/epydoc.conf b/python/epydoc.conf deleted file mode 100644 index 8593e08deda19..0000000000000 --- a/python/epydoc.conf +++ /dev/null @@ -1,38 +0,0 @@ -[epydoc] # Epydoc section marker (required by ConfigParser) - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. 
-# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Information about the project. -name: Spark 1.0.0 Python API Docs -url: http://spark.apache.org - -# The list of modules to document. Modules can be named using -# dotted names, module filenames, or package directory names. -# This option may be repeated. -modules: pyspark - -# Write html output to the directory "apidocs" -output: html -target: docs/ - -private: no - -exclude: pyspark.cloudpickle pyspark.worker pyspark.join - pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests - pyspark.rddsampler pyspark.daemon - pyspark.mllib.tests pyspark.shuffle diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 1a2e774738fe7..e39e6514d77a1 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -20,33 +20,21 @@ Public classes: - - L{SparkContext} + - :class:`SparkContext`: Main entry point for Spark functionality. - - L{RDD} + - L{RDD} A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. - - L{Broadcast} + - L{Broadcast} A broadcast variable that gets reused across tasks. - - L{Accumulator} + - L{Accumulator} An "add-only" shared variable that tasks can only add values to. - - L{SparkConf} + - L{SparkConf} For configuring Spark. - - L{SparkFiles} + - L{SparkFiles} Access files shipped with jobs. - - L{StorageLevel} + - L{StorageLevel} Finer-grained cache persistence levels. -Spark SQL: - - L{SQLContext} - Main entry point for SQL functionality. - - L{SchemaRDD} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, SchemaRDDs also support SQL. - - L{Row} - A Row of data returned by a Spark SQL query. - -Hive: - - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. """ # The following block allows us to import python's random instead of mllib.random for scripts in diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index b64875a3f495a..dc7cd0bce56f3 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -83,11 +83,11 @@ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None): """ Create a new Spark configuration. - @param loadDefaults: whether to load values from Java system + :param loadDefaults: whether to load values from Java system properties (True by default) - @param _jvm: internal parameter used to pass a handle to the + :param _jvm: internal parameter used to pass a handle to the Java VM; does not need to be set by users - @param _jconf: Optionally pass in an existing SparkConf handle + :param _jconf: Optionally pass in an existing SparkConf handle to use its parameters """ if _jconf: @@ -139,7 +139,7 @@ def setAll(self, pairs): """ Set multiple parameters, passed as a list of key-value pairs. 
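
These API docs are moving from epydoc (the deleted epydoc.conf, with L{...} and @param markup) to Sphinx, so docstrings are being rewritten as reST field lists. The target convention, shown on a hypothetical helper rather than any real pyspark function:

def set_if_missing(conf, key, value):
    """Set a configuration entry only if it is not already present.

    :param conf: a dict-like configuration object (hypothetical, for illustration)
    :param key: configuration key to set
    :param value: value to use when the key is absent
    :return: the same configuration object, to allow chaining
    """
    if key not in conf:
        conf[key] = value
    return conf


print(set_if_missing({"spark.app.name": "demo"}, "spark.master", "local[2]"))
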
- @param pairs: list of key-value pairs to set + :param pairs: list of key-value pairs to set """ for (k, v) in pairs: self._jconf.set(k, v) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 5b9c0311ef586..89d2e2e5b4a8e 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer, CompressedSerializer + PairDeserializer, CompressedSerializer, AutoBatchedSerializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -67,27 +67,28 @@ class SparkContext(object): _default_batch_size_for_serialized_input = 10 def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, - environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None, + environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, gateway=None, jsc=None): """ Create a new SparkContext. At least the master and app name should be set, either through the named parameters here or through C{conf}. - @param master: Cluster URL to connect to + :param master: Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - @param appName: A name for your job, to display on the cluster web UI. - @param sparkHome: Location where Spark is installed on cluster nodes. - @param pyFiles: Collection of .zip or .py files to send to the cluster + :param appName: A name for your job, to display on the cluster web UI. + :param sparkHome: Location where Spark is installed on cluster nodes. + :param pyFiles: Collection of .zip or .py files to send to the cluster and add to PYTHONPATH. These can be paths on the local file system or HDFS, HTTP, HTTPS, or FTP URLs. - @param environment: A dictionary of environment variables to set on + :param environment: A dictionary of environment variables to set on worker nodes. - @param batchSize: The number of Python objects represented as a single - Java object. Set 1 to disable batching or -1 to use an - unlimited batch size. - @param serializer: The serializer for RDDs. - @param conf: A L{SparkConf} object setting Spark properties. - @param gateway: Use an existing gateway and JVM, otherwise a new JVM + :param batchSize: The number of Python objects represented as a single + Java object. Set 1 to disable batching, 0 to automatically choose + the batch size based on object sizes, or -1 to use an unlimited + batch size + :param serializer: The serializer for RDDs. + :param conf: A L{SparkConf} object setting Spark properties. + :param gateway: Use an existing gateway and JVM, otherwise a new JVM will be instantiated. @@ -117,6 +118,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._unbatched_serializer = serializer if batchSize == 1: self.serializer = self._unbatched_serializer + elif batchSize == 0: + self.serializer = AutoBatchedSerializer(self._unbatched_serializer) else: self.serializer = BatchedSerializer(self._unbatched_serializer, batchSize) @@ -417,16 +420,16 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, 3. If this fails, the fallback is to call 'toString' on each key and value 4. 
C{PickleSerializer} is used to deserialize pickled objects on the Python side - @param path: path to sequncefile - @param keyClass: fully qualified classname of key Writable class + :param path: path to sequncefile + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: - @param valueConverter: - @param minSplits: minimum splits in dataset + :param keyConverter: + :param valueConverter: + :param minSplits: minimum splits in dataset (default min(2, sc.defaultParallelism)) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ minSplits = minSplits or min(self.defaultParallelism, 2) @@ -446,18 +449,18 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv A Hadoop configuration can be passed in as a Python dict. This will be converted into a Configuration in Java - @param path: path to Hadoop file - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param path: path to Hadoop file + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) @@ -476,17 +479,17 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N This will be converted into a Configuration in Java. The mechanism is the same as for sc.sequenceFile. - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. 
"org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) @@ -507,18 +510,18 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter= A Hadoop configuration can be passed in as a Python dict. This will be converted into a Configuration in Java. - @param path: path to Hadoop file - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param path: path to Hadoop file + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapred.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) @@ -537,17 +540,17 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, This will be converted into a Configuration in Java. The mechanism is the same as for sc.sequenceFile. - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapred.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. 
(default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index a765b1c4f7d87..cd43982191702 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -79,15 +79,15 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ Train a logistic regression model on the given data. - @param data: The training data. - @param iterations: The number of iterations (default: 100). - @param step: The step parameter used in SGD + :param data: The training data. + :param iterations: The number of iterations (default: 100). + :param step: The step parameter used in SGD (default: 1.0). - @param miniBatchFraction: Fraction of data to be used for each SGD + :param miniBatchFraction: Fraction of data to be used for each SGD iteration. - @param initialWeights: The initial weights (default: None). - @param regParam: The regularizer parameter (default: 1.0). - @param regType: The type of regularizer used for training + :param initialWeights: The initial weights (default: None). + :param regParam: The regularizer parameter (default: 1.0). + :param regType: The type of regularizer used for training our model. :Allowed values: @@ -151,15 +151,15 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, """ Train a support vector machine on the given data. - @param data: The training data. - @param iterations: The number of iterations (default: 100). - @param step: The step parameter used in SGD + :param data: The training data. + :param iterations: The number of iterations (default: 100). + :param step: The step parameter used in SGD (default: 1.0). - @param regParam: The regularizer parameter (default: 1.0). - @param miniBatchFraction: Fraction of data to be used for each SGD + :param regParam: The regularizer parameter (default: 1.0). + :param miniBatchFraction: Fraction of data to be used for each SGD iteration. - @param initialWeights: The initial weights (default: None). - @param regType: The type of regularizer used for training + :param initialWeights: The initial weights (default: None). + :param regType: The type of regularizer used for training our model. :Allowed values: @@ -238,10 +238,10 @@ def train(cls, data, lambda_=1.0): classification. By making every vector a 0-1 vector, it can also be used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). - @param data: RDD of NumPy vectors, one per element, where the first + :param data: RDD of NumPy vectors, one per element, where the first coordinate is the label and the rest is the feature vector (e.g. a count vector). - @param lambda_: The smoothing parameter + :param lambda_: The smoothing parameter """ sc = data.context jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(data._to_java_object_rdd(), lambda_) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py new file mode 100644 index 0000000000000..a44a27fd3b6a6 --- /dev/null +++ b/python/pyspark/mllib/feature.py @@ -0,0 +1,193 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
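
These @param-to-:param: conversions accompany a behavioural change made earlier in context.py: the default batchSize is now 0, which selects AutoBatchedSerializer (batch sizes derived from object sizes) instead of a fixed batch of 1024 objects. The selection logic reduces to the following sketch of the _do_init branch, assuming only the pyspark.serializers classes shown in the import diff:

from pyspark.serializers import (AutoBatchedSerializer, BatchedSerializer,
                                 PickleSerializer)


def choose_serializer(batch_size, unbatched=PickleSerializer()):
    if batch_size == 1:
        return unbatched                                  # no batching
    elif batch_size == 0:
        return AutoBatchedSerializer(unbatched)           # size-based auto batching
    else:
        return BatchedSerializer(unbatched, batch_size)   # fixed batch size (-1: unlimited)


print(choose_serializer(0))
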
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for feature in MLlib. +""" +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer + +from pyspark.mllib.linalg import _convert_to_vector + +__all__ = ['Word2Vec', 'Word2VecModel'] + + +class Word2VecModel(object): + """ + class for Word2Vec model + """ + def __init__(self, sc, java_model): + """ + :param sc: Spark context + :param java_model: Handle to Java model object + """ + self._sc = sc + self._java_model = java_model + + def __del__(self): + self._sc._gateway.detach(self._java_model) + + def transform(self, word): + """ + :param word: a word + :return: vector representation of word + Transforms a word to its vector representation + + Note: local use only + """ + # TODO: make transform usable in RDD operations from python side + result = self._java_model.transform(word) + return PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(result))) + + def findSynonyms(self, x, num): + """ + :param x: a word or a vector representation of word + :param num: number of synonyms to find + :return: array of (word, cosineSimilarity) + Find synonyms of a word + + Note: local use only + """ + # TODO: make findSynonyms usable in RDD operations from python side + ser = PickleSerializer() + if type(x) == str: + jlist = self._java_model.findSynonyms(x, num) + else: + bytes = bytearray(ser.dumps(_convert_to_vector(x))) + vec = self._sc._jvm.SerDe.loads(bytes) + jlist = self._java_model.findSynonyms(vec, num) + words, similarity = ser.loads(str(self._sc._jvm.SerDe.dumps(jlist))) + return zip(words, similarity) + + +class Word2Vec(object): + """ + Word2Vec creates vector representation of words in a text corpus. + The algorithm first constructs a vocabulary from the corpus + and then learns vector representation of words in the vocabulary. + The vector representation can be used as features in + natural language processing and machine learning algorithms. + + We used skip-gram model in our implementation and hierarchical softmax + method to train the model. The variable names in the implementation + matches the original C implementation. + For original C implementation, see https://code.google.com/p/word2vec/ + For research papers, see + Efficient Estimation of Word Representations in Vector Space + and + Distributed Representations of Words and Phrases and their Compositionality. + + >>> sentence = "a b " * 100 + "a c " * 10 + >>> localDoc = [sentence, sentence] + >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" ")) + >>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc) + >>> syms = model.findSynonyms("a", 2) + >>> str(syms[0][0]) + 'b' + >>> str(syms[1][0]) + 'c' + >>> len(syms) + 2 + >>> vec = model.transform("a") + >>> len(vec) + 10 + >>> syms = model.findSynonyms(vec, 2) + >>> str(syms[0][0]) + 'b' + >>> str(syms[1][0]) + 'c' + >>> len(syms) + 2 + """ + def __init__(self): + """ + Construct Word2Vec instance + """ + self.vectorSize = 100 + self.learningRate = 0.025 + self.numPartitions = 1 + self.numIterations = 1 + self.seed = 42L + + def setVectorSize(self, vectorSize): + """ + Sets vector size (default: 100). 
+ """ + self.vectorSize = vectorSize + return self + + def setLearningRate(self, learningRate): + """ + Sets initial learning rate (default: 0.025). + """ + self.learningRate = learningRate + return self + + def setNumPartitions(self, numPartitions): + """ + Sets number of partitions (default: 1). Use a small number for accuracy. + """ + self.numPartitions = numPartitions + return self + + def setNumIterations(self, numIterations): + """ + Sets number of iterations (default: 1), which should be smaller than or equal to number of + partitions. + """ + self.numIterations = numIterations + return self + + def setSeed(self, seed): + """ + Sets random seed. + """ + self.seed = seed + return self + + def fit(self, data): + """ + Computes the vector representation of each word in vocabulary. + + :param data: training data. RDD of subtype of Iterable[String] + :return: python Word2VecModel instance + """ + sc = data.context + ser = PickleSerializer() + vectorSize = self.vectorSize + learningRate = self.learningRate + numPartitions = self.numPartitions + numIterations = self.numIterations + seed = self.seed + + model = sc._jvm.PythonMLLibAPI().trainWord2Vec( + data._to_java_object_rdd(), vectorSize, + learningRate, numPartitions, numIterations, seed) + return Word2VecModel(sc, model) + + +def _test(): + import doctest + from pyspark import SparkContext + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 51014a8ceb785..24c5480b2f753 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -238,8 +238,8 @@ def __init__(self, size, *args): (index, value) pairs, or two separate arrays of indices and values (sorted by index). - @param size: Size of the vector. - @param args: Non-zero entries, as a dictionary, list of tupes, + :param size: Size of the vector. + :param args: Non-zero entries, as a dictionary, list of tupes, or two sorted lists containing indices and values. >>> print SparseVector(4, {1: 1.0, 3: 5.5}) @@ -458,8 +458,8 @@ def sparse(size, *args): (index, value) pairs, or two separate arrays of indices and values (sorted by index). - @param size: Size of the vector. - @param args: Non-zero entries, as a dictionary, list of tupes, + :param size: Size of the vector. + :param args: Non-zero entries, as a dictionary, list of tupes, or two sorted lists containing indices and values. >>> print Vectors.sparse(4, {1: 1.0, 3: 5.5}) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 54f34a98337ca..12b322aaae796 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -31,8 +31,8 @@ class LabeledPoint(object): """ The features and labels of a data point. - @param label: Label for this data point. - @param features: Vector of features for this point (NumPy array, list, + :param label: Label for this data point. + :param features: Vector of features for this point (NumPy array, list, pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix) """ @@ -145,15 +145,15 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ Train a linear regression model on the given data. - @param data: The training data. - @param iterations: The number of iterations (default: 100). 
- @param step: The step parameter used in SGD + :param data: The training data. + :param iterations: The number of iterations (default: 100). + :param step: The step parameter used in SGD (default: 1.0). - @param miniBatchFraction: Fraction of data to be used for each SGD + :param miniBatchFraction: Fraction of data to be used for each SGD iteration. - @param initialWeights: The initial weights (default: None). - @param regParam: The regularizer parameter (default: 1.0). - @param regType: The type of regularizer used for training + :param initialWeights: The initial weights (default: None). + :param regParam: The regularizer parameter (default: 1.0). + :param regType: The type of regularizer used for training our model. :Allowed values: diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 8233d4e81f1ca..1357fd4fbc8aa 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -77,10 +77,10 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None method parses each line into a LabeledPoint, where the feature indices are converted to zero-based. - @param sc: Spark context - @param path: file or directory path in any Hadoop-supported file + :param sc: Spark context + :param path: file or directory path in any Hadoop-supported file system URI - @param numFeatures: number of features, which will be determined + :param numFeatures: number of features, which will be determined from the input data if a nonpositive value is given. This is useful when the dataset is already split into multiple files and you @@ -88,7 +88,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None features may not present in certain files, which leads to inconsistent feature dimensions. - @param minPartitions: min number of partitions + :param minPartitions: min number of partitions @return: labeled data stored as an RDD of LabeledPoint >>> from tempfile import NamedTemporaryFile @@ -126,8 +126,8 @@ def saveAsLibSVMFile(data, dir): """ Save labeled data in LIBSVM format. - @param data: an RDD of LabeledPoint to be saved - @param dir: directory to save the data + :param data: an RDD of LabeledPoint to be saved + :param dir: directory to save the data >>> from tempfile import NamedTemporaryFile >>> from fileinput import input @@ -149,10 +149,10 @@ def loadLabeledPoints(sc, path, minPartitions=None): """ Load labeled points saved using RDD.saveAsTextFile. - @param sc: Spark context - @param path: file or directory path in any Hadoop-supported file + :param sc: Spark context + :param path: file or directory path in any Hadoop-supported file system URI - @param minPartitions: min number of partitions + :param minPartitions: min number of partitions @return: labeled data stored as an RDD of LabeledPoint >>> from tempfile import NamedTemporaryFile diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index e77669aad76b6..6797d50659a92 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -752,7 +752,7 @@ def max(self, key=None): """ Find the maximum item in this RDD. - @param key: A function used to generate key for comparing + :param key: A function used to generate key for comparing >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0]) >>> rdd.max() @@ -768,7 +768,7 @@ def min(self, key=None): """ Find the minimum item in this RDD. 
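
The `key` argument documented for `max`/`min` is easiest to see with a small example (hypothetical values, assuming `sc` as in the doctests):

```python
rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0])
rdd.min()          # 2.0  -- natural ordering
rdd.min(key=str)   # 10.0 -- "10.0" sorts first when comparing string forms
rdd.max(key=str)   # 5.0  -- "5.0" sorts last when comparing string forms
```
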
- @param key: A function used to generate key for comparing + :param key: A function used to generate key for comparing >>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0]) >>> rdd.min() @@ -1115,9 +1115,9 @@ def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None converted for output using either user specified converters or, by default, L{org.apache.spark.api.python.JavaToWritableConverter}. - @param conf: Hadoop job configuration, passed in as a dict - @param keyConverter: (None by default) - @param valueConverter: (None by default) + :param conf: Hadoop job configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._toPickleSerialization() @@ -1135,16 +1135,16 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl C{conf} is applied on top of the base Hadoop conf associated with the SparkContext of this RDD to create a merged Hadoop MapReduce job configuration for saving the data. - @param path: path to Hadoop file - @param outputFormatClass: fully qualified classname of Hadoop OutputFormat + :param path: path to Hadoop file + :param outputFormatClass: fully qualified classname of Hadoop OutputFormat (e.g. "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.IntWritable", None by default) - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.Text", None by default) - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop job configuration, passed in as a dict (None by default) + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop job configuration, passed in as a dict (None by default) """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._toPickleSerialization() @@ -1161,9 +1161,9 @@ def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): converted for output using either user specified converters or, by default, L{org.apache.spark.api.python.JavaToWritableConverter}. - @param conf: Hadoop job configuration, passed in as a dict - @param keyConverter: (None by default) - @param valueConverter: (None by default) + :param conf: Hadoop job configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._toPickleSerialization() @@ -1182,17 +1182,17 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No C{conf} is applied on top of the base Hadoop conf associated with the SparkContext of this RDD to create a merged Hadoop MapReduce job configuration for saving the data. - @param path: path to Hadoop file - @param outputFormatClass: fully qualified classname of Hadoop OutputFormat + :param path: path to Hadoop file + :param outputFormatClass: fully qualified classname of Hadoop OutputFormat (e.g. "org.apache.hadoop.mapred.SequenceFileOutputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. 
"org.apache.hadoop.io.IntWritable", None by default) - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.Text", None by default) - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: (None by default) - @param compressionCodecClass: (None by default) + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: (None by default) + :param compressionCodecClass: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._toPickleSerialization() @@ -1212,8 +1212,8 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None): 1. Pyrolite is used to convert pickled Python RDD into RDD of Java objects. 2. Keys and values of this Java RDD are converted to Writables and written out. - @param path: path to sequence file - @param compressionCodecClass: (None by default) + :param path: path to sequence file + :param compressionCodecClass: (None by default) """ pickledRDD = self._toPickleSerialization() batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) @@ -2009,7 +2009,7 @@ def countApproxDistinct(self, relativeSD=0.05): of The Art Cardinality Estimation Algorithm", available here. - @param relativeSD Relative accuracy. Smaller values create + :param relativeSD Relative accuracy. Smaller values create counters that require more space. It must be greater than 0.000017. diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 6e7acdd817b9a..08a0f0d8ffb3e 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -223,7 +223,7 @@ class AutoBatchedSerializer(BatchedSerializer): Choose the size of batch automatically based on the size of object """ - def __init__(self, serializer, bestSize=1 << 20): + def __init__(self, serializer, bestSize=1 << 16): BatchedSerializer.__init__(self, serializer, -1) self.bestSize = bestSize @@ -250,7 +250,7 @@ def __eq__(self, other): other.serializer == self.serializer) def __str__(self): - return "BatchedSerializer<%s>" % str(self.serializer) + return "AutoBatchedSerializer<%s>" % str(self.serializer) class CartesianDeserializer(FramedSerializer): diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 114644ab8b79d..d3d36eb995ab6 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -15,28 +15,38 @@ # limitations under the License. # +""" +public classes of Spark SQL: + + - L{SQLContext} + Main entry point for SQL functionality. + - L{SchemaRDD} + A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In + addition to normal RDD operations, SchemaRDDs also support SQL. + - L{Row} + A Row of data returned by a Spark SQL query. + - L{HiveContext} + Main entry point for accessing data stored in Apache Hive.. 
+""" -import sys -import types import itertools -import warnings import decimal import datetime import keyword import warnings +import json from array import array from operator import itemgetter +from itertools import imap + +from py4j.protocol import Py4JError +from py4j.java_collections import ListConverter, MapConverter from pyspark.rdd import RDD from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync -from itertools import chain, ifilter, imap - -from py4j.protocol import Py4JError -from py4j.java_collections import ListConverter, MapConverter - __all__ = [ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", @@ -62,6 +72,18 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + @classmethod + def typeName(cls): + return cls.__name__[:-4].lower() + + def jsonValue(self): + return self.typeName() + + def json(self): + return json.dumps(self.jsonValue(), + separators=(',', ':'), + sort_keys=True) + class PrimitiveTypeSingleton(type): @@ -205,6 +227,16 @@ def __repr__(self): return "ArrayType(%s,%s)" % (self.elementType, str(self.containsNull).lower()) + def jsonValue(self): + return {"type": self.typeName(), + "elementType": self.elementType.jsonValue(), + "containsNull": self.containsNull} + + @classmethod + def fromJson(cls, json): + return ArrayType(_parse_datatype_json_value(json["elementType"]), + json["containsNull"]) + class MapType(DataType): @@ -245,6 +277,18 @@ def __repr__(self): return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, str(self.valueContainsNull).lower()) + def jsonValue(self): + return {"type": self.typeName(), + "keyType": self.keyType.jsonValue(), + "valueType": self.valueType.jsonValue(), + "valueContainsNull": self.valueContainsNull} + + @classmethod + def fromJson(cls, json): + return MapType(_parse_datatype_json_value(json["keyType"]), + _parse_datatype_json_value(json["valueType"]), + json["valueContainsNull"]) + class StructField(DataType): @@ -283,6 +327,17 @@ def __repr__(self): return "StructField(%s,%s,%s)" % (self.name, self.dataType, str(self.nullable).lower()) + def jsonValue(self): + return {"name": self.name, + "type": self.dataType.jsonValue(), + "nullable": self.nullable} + + @classmethod + def fromJson(cls, json): + return StructField(json["name"], + _parse_datatype_json_value(json["type"]), + json["nullable"]) + class StructType(DataType): @@ -312,42 +367,30 @@ def __repr__(self): return ("StructType(List(%s))" % ",".join(str(field) for field in self.fields)) + def jsonValue(self): + return {"type": self.typeName(), + "fields": [f.jsonValue() for f in self.fields]} -def _parse_datatype_list(datatype_list_string): - """Parses a list of comma separated data types.""" - index = 0 - datatype_list = [] - start = 0 - depth = 0 - while index < len(datatype_list_string): - if depth == 0 and datatype_list_string[index] == ",": - datatype_string = datatype_list_string[start:index].strip() - datatype_list.append(_parse_datatype_string(datatype_string)) - start = index + 1 - elif datatype_list_string[index] == "(": - depth += 1 - elif datatype_list_string[index] == ")": - depth -= 1 + @classmethod + def fromJson(cls, json): + return StructType([StructField.fromJson(f) for f in json["fields"]]) - index += 1 - # Handle the last data type - datatype_string = datatype_list_string[start:index].strip() - datatype_list.append(_parse_datatype_string(datatype_string)) - 
return datatype_list +_all_primitive_types = dict((v.typeName(), v) + for v in globals().itervalues() + if type(v) is PrimitiveTypeSingleton and + v.__base__ == PrimitiveType) -_all_primitive_types = dict((k, v) for k, v in globals().iteritems() - if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType) +_all_complex_types = dict((v.typeName(), v) + for v in [ArrayType, MapType, StructType]) -def _parse_datatype_string(datatype_string): - """Parses the given data type string. - +def _parse_datatype_json_string(json_string): + """Parses the given data type JSON string. >>> def check_datatype(datatype): - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype)) - ... python_datatype = _parse_datatype_string( - ... scala_datatype.toString()) + ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) + ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... return datatype == python_datatype >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) True @@ -385,51 +428,14 @@ def _parse_datatype_string(datatype_string): >>> check_datatype(complex_maptype) True """ - index = datatype_string.find("(") - if index == -1: - # It is a primitive type. - index = len(datatype_string) - type_or_field = datatype_string[:index] - rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip() - - if type_or_field in _all_primitive_types: - return _all_primitive_types[type_or_field]() - - elif type_or_field == "ArrayType": - last_comma_index = rest_part.rfind(",") - containsNull = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - containsNull = False - elementType = _parse_datatype_string( - rest_part[:last_comma_index].strip()) - return ArrayType(elementType, containsNull) - - elif type_or_field == "MapType": - last_comma_index = rest_part.rfind(",") - valueContainsNull = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - valueContainsNull = False - keyType, valueType = _parse_datatype_list( - rest_part[:last_comma_index].strip()) - return MapType(keyType, valueType, valueContainsNull) - - elif type_or_field == "StructField": - first_comma_index = rest_part.find(",") - name = rest_part[:first_comma_index].strip() - last_comma_index = rest_part.rfind(",") - nullable = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - nullable = False - dataType = _parse_datatype_string( - rest_part[first_comma_index + 1:last_comma_index].strip()) - return StructField(name, dataType, nullable) - - elif type_or_field == "StructType": - # rest_part should be in the format like - # List(StructField(field1,IntegerType,false)). - field_list_string = rest_part[rest_part.find("(") + 1:-1] - fields = _parse_datatype_list(field_list_string) - return StructType(fields) + return _parse_datatype_json_value(json.loads(json_string)) + + +def _parse_datatype_json_value(json_value): + if type(json_value) is unicode and json_value in _all_primitive_types.keys(): + return _all_primitive_types[json_value]() + else: + return _all_complex_types[json_value["type"]].fromJson(json_value) # Mapping Python types to Spark SQL DateType @@ -899,8 +905,8 @@ class SQLContext(object): def __init__(self, sparkContext, sqlContext=None): """Create a new SQLContext. - @param sparkContext: The SparkContext to wrap. - @param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new + :param sparkContext: The SparkContext to wrap. + :param sqlContext: An optional JVM Scala SQLContext. 
If set, we do not instatiate a new SQLContext in the JVM, instead we make all calls to this object. >>> srdd = sqlCtx.inferSchema(rdd) @@ -983,7 +989,7 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc.pythonExec, broadcast_vars, self._sc._javaAccumulator, - str(returnType)) + returnType.json()) def inferSchema(self, rdd): """Infer and apply a schema to an RDD of L{Row}. @@ -1119,7 +1125,7 @@ def applySchema(self, rdd, schema): batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) jrdd = self._pythonToJava(rdd._jrdd, batched) - srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) + srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return SchemaRDD(srdd.toJavaSchemaRDD(), self) def registerRDDAsTable(self, rdd, tableName): @@ -1209,7 +1215,7 @@ def jsonFile(self, path, schema=None): if schema is None: srdd = self._ssql_ctx.jsonFile(path) else: - scala_datatype = self._ssql_ctx.parseDataType(str(schema)) + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1279,7 +1285,7 @@ def func(iterator): if schema is None: srdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) else: - scala_datatype = self._ssql_ctx.parseDataType(str(schema)) + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1325,8 +1331,8 @@ class HiveContext(SQLContext): def __init__(self, sparkContext, hiveContext=None): """Create a new HiveContext. - @param sparkContext: The SparkContext to wrap. - @param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new + :param sparkContext: The SparkContext to wrap. + :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new HiveContext in the JVM, instead we make all calls to this object. """ SQLContext.__init__(self, sparkContext) @@ -1614,7 +1620,7 @@ def saveAsTable(self, tableName): def schema(self): """Returns the schema of this SchemaRDD (represented by a L{StructType}).""" - return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString()) + return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json()) def schemaString(self): """Returns the output schema in the tree format.""" diff --git a/python/run-tests b/python/run-tests index f799902b4c322..2f98443c30aef 100755 --- a/python/run-tests +++ b/python/run-tests @@ -25,16 +25,17 @@ FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)" cd "$FWDIR/python" FAILED=0 +LOG_FILE=unit-tests.log -rm -f unit-tests.log +rm -f $LOG_FILE # Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL rm -rf metastore warehouse function run_test() { - echo "Running test: $1" + echo "Running test: $1" | tee -a $LOG_FILE - SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log + SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a $LOG_FILE FAILED=$((PIPESTATUS[0]||$FAILED)) @@ -69,6 +70,7 @@ function run_mllib_tests() { echo "Run mllib tests ..." 
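
The switch above from `str(schema)` to `schema.json()` means Python and Scala now exchange schemas as JSON instead of re-parsing `toString` output. A quick round-trip on the Python side illustrates the new `jsonValue`/`json`/`fromJson` methods (field names are made up):

```python
import json
from pyspark.sql import StructType, StructField, StringType, IntegerType

schema = StructType([
    StructField("name", StringType(), True),
    StructField("age", IntegerType(), True),
])

# json() produces the compact, sorted-keys form that is handed to the JVM ...
schema.json()   # e.g. '{"fields":[...],"type":"struct"}'

# ... and fromJson (used by _parse_datatype_json_string) rebuilds an equal DataType.
StructType.fromJson(json.loads(schema.json())) == schema   # True
```
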
run_test "pyspark/mllib/classification.py" run_test "pyspark/mllib/clustering.py" + run_test "pyspark/mllib/feature.py" run_test "pyspark/mllib/linalg.py" run_test "pyspark/mllib/random.py" run_test "pyspark/mllib/recommendation.py" diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 6ddb6accd696b..646c68e60c2e9 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -84,9 +84,11 @@ import org.apache.spark.util.Utils * @author Moez A. Abdel-Gawad * @author Lex Spoon */ - class SparkIMain(initialSettings: Settings, val out: JPrintWriter) - extends SparkImports with Logging { - imain => + class SparkIMain( + initialSettings: Settings, + val out: JPrintWriter, + propagateExceptions: Boolean = false) + extends SparkImports with Logging { imain => val conf = new SparkConf() @@ -816,6 +818,10 @@ import org.apache.spark.util.Utils val resultName = FixedSessionNames.resultName def bindError(t: Throwable) = { + // Immediately throw the exception if we are asked to propagate them + if (propagateExceptions) { + throw unwrap(t) + } if (!bindExceptions) // avoid looping if already binding throw t diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index 3e2ee7541f40d..6a79e76a34db8 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -23,8 +23,6 @@ import java.net.{URL, URLClassLoader} import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite -import com.google.common.io.Files - import org.apache.spark.{SparkConf, TestUtils} import org.apache.spark.util.Utils @@ -39,10 +37,8 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { override def beforeAll() { super.beforeAll() - tempDir1 = Files.createTempDir() - tempDir1.deleteOnExit() - tempDir2 = Files.createTempDir() - tempDir2.deleteOnExit() + tempDir1 = Utils.createTempDir() + tempDir2 = Utils.createTempDir() url1 = "file://" + tempDir1 urls2 = List(tempDir2.toURI.toURL).toArray childClassNames.foreach(TestUtils.createCompiledClass(_, tempDir1, "1")) diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index c8763eb277052..91c9c52c3c98a 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -22,7 +22,6 @@ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer -import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.spark.SparkContext import org.apache.commons.lang3.StringEscapeUtils @@ -190,8 +189,7 @@ class ReplSuite extends FunSuite { } test("interacting with files") { - val tempDir = Files.createTempDir() - tempDir.deleteOnExit() + val tempDir = Utils.createTempDir() val out = new FileWriter(tempDir + "/input") out.write("Hello world!\n") out.write("What's up?\n") diff --git a/sql/README.md b/sql/README.md index 31f9152344086..c84534da9a3d3 100644 --- a/sql/README.md +++ b/sql/README.md @@ -44,38 +44,37 @@ Type in expressions to have them evaluated. Type :help for more information. 
scala> val query = sql("SELECT * FROM (SELECT * FROM src) a") -query: org.apache.spark.sql.ExecutedQuery = -SELECT * FROM (SELECT * FROM src) a -=== Query Plan === -Project [key#6:0.0,value#7:0.1] - HiveTableScan [key#6,value#7], (MetastoreRelation default, src, None), None +query: org.apache.spark.sql.SchemaRDD = +== Query Plan == +== Physical Plan == +HiveTableScan [key#10,value#11], (MetastoreRelation default, src, None), None ``` Query results are RDDs and can be operated as such. ``` scala> query.collect() -res8: Array[org.apache.spark.sql.execution.Row] = Array([238,val_238], [86,val_86], [311,val_311]... +res2: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,val_311], [27,val_27]... ``` You can also build further queries on top of these RDDs using the query DSL. ``` -scala> query.where('key === 100).toRdd.collect() -res11: Array[org.apache.spark.sql.execution.Row] = Array([100,val_100], [100,val_100]) +scala> query.where('key === 100).collect() +res3: Array[org.apache.spark.sql.Row] = Array([100,val_100], [100,val_100]) ``` -From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](http://databricks.github.io/catalyst/latest/api/#catalyst.trees.TreeNode) objects. +From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala) objects. ```scala -scala> query.logicalPlan -res1: catalyst.plans.logical.LogicalPlan = -Project {key#0,value#1} - Project {key#0,value#1} +scala> query.queryExecution.analyzed +res4: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan = +Project [key#10,value#11] + Project [key#10,value#11] MetastoreRelation default, src, None -scala> query.logicalPlan transform { +scala> query.queryExecution.analyzed transform { | case Project(projectList, child) if projectList == child.output => child | } -res2: catalyst.plans.logical.LogicalPlan = -Project {key#0,value#1} +res5: res17: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan = +Project [key#10,value#11] MetastoreRelation default, src, None ``` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala new file mode 100644 index 0000000000000..04467342e6ab5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import scala.language.implicitConversions +import scala.util.parsing.combinator.lexical.StdLexical +import scala.util.parsing.combinator.syntactical.StandardTokenParsers +import scala.util.parsing.combinator.{PackratParsers, RegexParsers} +import scala.util.parsing.input.CharArrayReader.EofCh + +import org.apache.spark.sql.catalyst.plans.logical._ + +private[sql] abstract class AbstractSparkSQLParser + extends StandardTokenParsers with PackratParsers { + + def apply(input: String): LogicalPlan = phrase(start)(new lexical.Scanner(input)) match { + case Success(plan, _) => plan + case failureOrError => sys.error(failureOrError.toString) + } + + protected case class Keyword(str: String) + + protected def start: Parser[LogicalPlan] + + // Returns the whole input string + protected lazy val wholeInput: Parser[String] = new Parser[String] { + def apply(in: Input): ParseResult[String] = + Success(in.source.toString, in.drop(in.source.length())) + } + + // Returns the rest of the input string that are not parsed yet + protected lazy val restInput: Parser[String] = new Parser[String] { + def apply(in: Input): ParseResult[String] = + Success( + in.source.subSequence(in.offset, in.source.length()).toString, + in.drop(in.source.length())) + } +} + +class SqlLexical(val keywords: Seq[String]) extends StdLexical { + case class FloatLit(chars: String) extends Token { + override def toString = chars + } + + reserved ++= keywords.flatMap(w => allCaseVersions(w)) + + delimiters += ( + "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", + ",", ";", "%", "{", "}", ":", "[", "]", "." + ) + + override lazy val token: Parser[Token] = + ( identChar ~ (identChar | digit).* ^^ + { case first ~ rest => processIdent((first :: rest).mkString) } + | rep1(digit) ~ ('.' ~> digit.*).? ^^ { + case i ~ None => NumericLit(i.mkString) + case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString) + } + | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ + { case chars => StringLit(chars mkString "") } + | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^ + { case chars => StringLit(chars mkString "") } + | EofCh ^^^ EOF + | '\'' ~> failure("unclosed string literal") + | '"' ~> failure("unclosed string literal") + | delim + | failure("illegal character") + ) + + override def identChar = letter | elem('_') + + override def whitespace: Parser[Any] = + ( whitespaceChar + | '/' ~ '*' ~ comment + | '/' ~ '/' ~ chrExcept(EofCh, '\n').* + | '#' ~ chrExcept(EofCh, '\n').* + | '-' ~ '-' ~ chrExcept(EofCh, '\n').* + | '/' ~ '*' ~ failure("unclosed comment") + ).* + + /** Generate all variations of upper and lower case of a given string */ + def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { + if (s == "") { + Stream(prefix) + } else { + allCaseVersions(s.tail, prefix + s.head.toLower) ++ + allCaseVersions(s.tail, prefix + s.head.toUpper) + } + } +} + +/** + * The top level Spark SQL parser. 
This parser recognizes syntaxes that are available for all SQL + * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser. + * + * @param fallback A function that parses an input string to a logical plan + */ +private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { + + // A parser for the key-value part of the "SET [key = [value ]]" syntax + private object SetCommandParser extends RegexParsers { + private val key: Parser[String] = "(?m)[^=]+".r + + private val value: Parser[String] = "(?m).*$".r + + private val pair: Parser[LogicalPlan] = + (key ~ ("=".r ~> value).?).? ^^ { + case None => SetCommand(None) + case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim))) + } + + def apply(input: String): LogicalPlan = parseAll(pair, input) match { + case Success(plan, _) => plan + case x => sys.error(x.toString) + } + } + + protected val AS = Keyword("AS") + protected val CACHE = Keyword("CACHE") + protected val LAZY = Keyword("LAZY") + protected val SET = Keyword("SET") + protected val TABLE = Keyword("TABLE") + protected val SOURCE = Keyword("SOURCE") + protected val UNCACHE = Keyword("UNCACHE") + + protected implicit def asParser(k: Keyword): Parser[String] = + lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + + private val reservedWords: Seq[String] = + this + .getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].str) + + override val lexical = new SqlLexical(reservedWords) + + override protected lazy val start: Parser[LogicalPlan] = + cache | uncache | set | shell | source | others + + private lazy val cache: Parser[LogicalPlan] = + CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ { + case isLazy ~ tableName ~ plan => + CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined) + } + + private lazy val uncache: Parser[LogicalPlan] = + UNCACHE ~ TABLE ~> ident ^^ { + case tableName => UncacheTableCommand(tableName) + } + + private lazy val set: Parser[LogicalPlan] = + SET ~> restInput ^^ { + case input => SetCommandParser(input) + } + + private lazy val shell: Parser[LogicalPlan] = + "!" ~> restInput ^^ { + case input => ShellCommand(input.trim) + } + + private lazy val source: Parser[LogicalPlan] = + SOURCE ~> restInput ^^ { + case input => SourceCommand(input.trim) + } + + private lazy val others: Parser[LogicalPlan] = + wholeInput ^^ { + case input => fallback(input) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 854b5b461bdc8..b4d606d37e732 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -18,10 +18,6 @@ package org.apache.spark.sql.catalyst import scala.language.implicitConversions -import scala.util.parsing.combinator.lexical.StdLexical -import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.PackratParsers -import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -39,31 +35,7 @@ import org.apache.spark.sql.catalyst.types._ * This is currently included mostly for illustrative purposes. 
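
For reference, the statements handled directly by the new `SparkSQLParser` defined above are the ones users type themselves; a hypothetical PySpark session (table names made up) that exercises them might look like this:

```python
# These statements are recognized by SparkSQLParser itself; anything else
# falls through to the SQL dialect parser supplied as `fallback`.
sqlCtx.sql("CACHE TABLE people")
sqlCtx.sql("CACHE LAZY TABLE people_copy AS SELECT * FROM people")
sqlCtx.sql("SET spark.sql.shuffle.partitions=10")
sqlCtx.sql("UNCACHE TABLE people")
```
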
Users wanting more complete support * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ -class SqlParser extends StandardTokenParsers with PackratParsers { - - def apply(input: String): LogicalPlan = { - // Special-case out set commands since the value fields can be - // complex to handle without RegexParsers. Also this approach - // is clearer for the several possible cases of set commands. - if (input.trim.toLowerCase.startsWith("set")) { - input.trim.drop(3).split("=", 2).map(_.trim) match { - case Array("") => // "set" - SetCommand(None, None) - case Array(key) => // "set key" - SetCommand(Some(key), None) - case Array(key, value) => // "set key=value" - SetCommand(Some(key), Some(value)) - } - } else { - phrase(query)(new lexical.Scanner(input)) match { - case Success(r, x) => r - case x => sys.error(x.toString) - } - } - } - - protected case class Keyword(str: String) - +class SqlParser extends AbstractSparkSQLParser { protected implicit def asParser(k: Keyword): Parser[String] = lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) @@ -77,10 +49,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val BETWEEN = Keyword("BETWEEN") protected val BY = Keyword("BY") protected val CACHE = Keyword("CACHE") + protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") protected val COUNT = Keyword("COUNT") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") + protected val ELSE = Keyword("ELSE") + protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") protected val FALSE = Keyword("FALSE") protected val FIRST = Keyword("FIRST") @@ -97,7 +72,6 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val IS = Keyword("IS") protected val JOIN = Keyword("JOIN") protected val LAST = Keyword("LAST") - protected val LAZY = Keyword("LAZY") protected val LEFT = Keyword("LEFT") protected val LIKE = Keyword("LIKE") protected val LIMIT = Keyword("LIMIT") @@ -122,16 +96,18 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val SUBSTRING = Keyword("SUBSTRING") protected val SUM = Keyword("SUM") protected val TABLE = Keyword("TABLE") + protected val THEN = Keyword("THEN") protected val TIMESTAMP = Keyword("TIMESTAMP") protected val TRUE = Keyword("TRUE") - protected val UNCACHE = Keyword("UNCACHE") protected val UNION = Keyword("UNION") protected val UPPER = Keyword("UPPER") + protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") // Use reflection to find the reserved words defined in this class. 
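
The CASE/WHEN/THEN/ELSE/END keywords registered above back the `CASE WHEN` expression rule added further down in this parser. From the user's side, a hypothetical PySpark query against a registered table would look like this:

```python
# Both the "searched" form (CASE WHEN cond THEN ...) and the keyed form
# (CASE expr WHEN value THEN ...) are accepted by the updated SqlParser.
sqlCtx.sql("""
    SELECT name,
           CASE WHEN age >= 18 THEN 'adult' ELSE 'minor' END
    FROM people
""").collect()
```
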
protected val reservedWords = - this.getClass + this + .getClass .getMethods .filter(_.getReturnType == classOf[Keyword]) .map(_.invoke(this).asInstanceOf[Keyword].str) @@ -145,86 +121,68 @@ class SqlParser extends StandardTokenParsers with PackratParsers { } } - protected lazy val query: Parser[LogicalPlan] = ( - select * ( - UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } | - INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } | - EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} | - UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } + protected lazy val start: Parser[LogicalPlan] = + ( select * + ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } + | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } + | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} + | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } ) - | insert | cache | unCache - ) + | insert + ) protected lazy val select: Parser[LogicalPlan] = - SELECT ~> opt(DISTINCT) ~ projections ~ - opt(from) ~ opt(filter) ~ - opt(grouping) ~ - opt(having) ~ - opt(orderBy) ~ - opt(limit) <~ opt(";") ^^ { - case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => - val base = r.getOrElse(NoRelation) - val withFilter = f.map(f => Filter(f, base)).getOrElse(base) - val withProjection = - g.map {g => - Aggregate(g, assignAliases(p), withFilter) - }.getOrElse(Project(assignAliases(p), withFilter)) - val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) - val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct) - val withOrder = o.map(o => Sort(o, withHaving)).getOrElse(withHaving) - val withLimit = l.map { l => Limit(l, withOrder) }.getOrElse(withOrder) - withLimit - } + SELECT ~> DISTINCT.? ~ + repsep(projection, ",") ~ + (FROM ~> relations).? ~ + (WHERE ~> expression).? ~ + (GROUP ~ BY ~> rep1sep(expression, ",")).? ~ + (HAVING ~> expression).? ~ + (ORDER ~ BY ~> ordering).? ~ + (LIMIT ~> expression).? ^^ { + case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => + val base = r.getOrElse(NoRelation) + val withFilter = f.map(f => Filter(f, base)).getOrElse(base) + val withProjection = g + .map(Aggregate(_, assignAliases(p), withFilter)) + .getOrElse(Project(assignAliases(p), withFilter)) + val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) + val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct) + val withOrder = o.map(Sort(_, withHaving)).getOrElse(withHaving) + val withLimit = l.map(Limit(_, withOrder)).getOrElse(withOrder) + withLimit + } protected lazy val insert: Parser[LogicalPlan] = - INSERT ~> opt(OVERWRITE) ~ inTo ~ select <~ opt(";") ^^ { - case o ~ r ~ s => - val overwrite: Boolean = o.getOrElse("") == "OVERWRITE" - InsertIntoTable(r, Map[String, Option[String]](), s, overwrite) - } - - protected lazy val cache: Parser[LogicalPlan] = - CACHE ~> opt(LAZY) ~ (TABLE ~> ident) ~ opt(AS ~> select) <~ opt(";") ^^ { - case isLazy ~ tableName ~ plan => - CacheTableCommand(tableName, plan, isLazy.isDefined) - } - - protected lazy val unCache: Parser[LogicalPlan] = - UNCACHE ~ TABLE ~> ident <~ opt(";") ^^ { - case tableName => UncacheTableCommand(tableName) + INSERT ~> OVERWRITE.? 
~ (INTO ~> relation) ~ select ^^ { + case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o.isDefined) } - protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",") - protected lazy val projection: Parser[Expression] = - expression ~ (opt(AS) ~> opt(ident)) ^^ { - case e ~ None => e - case e ~ Some(a) => Alias(e, a)() + expression ~ (AS.? ~> ident.?) ^^ { + case e ~ a => a.fold(e)(Alias(e, _)()) } - protected lazy val from: Parser[LogicalPlan] = FROM ~> relations - - protected lazy val inTo: Parser[LogicalPlan] = INTO ~> relation - // Based very loosely on the MySQL Grammar. // http://dev.mysql.com/doc/refman/5.0/en/join.html protected lazy val relations: Parser[LogicalPlan] = - relation ~ "," ~ relation ^^ { case r1 ~ _ ~ r2 => Join(r1, r2, Inner, None) } | - relation + ( relation ~ ("," ~> relation) ^^ { case r1 ~ r2 => Join(r1, r2, Inner, None) } + | relation + ) protected lazy val relation: Parser[LogicalPlan] = - joinedRelation | - relationFactor + joinedRelation | relationFactor protected lazy val relationFactor: Parser[LogicalPlan] = - ident ~ (opt(AS) ~> opt(ident)) ^^ { - case tableName ~ alias => UnresolvedRelation(None, tableName, alias) - } | - "(" ~> query ~ ")" ~ opt(AS) ~ ident ^^ { case s ~ _ ~ _ ~ a => Subquery(a, s) } + ( ident ~ (opt(AS) ~> opt(ident)) ^^ { + case tableName ~ alias => UnresolvedRelation(None, tableName, alias) + } + | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) } + ) protected lazy val joinedRelation: Parser[LogicalPlan] = - relationFactor ~ opt(joinType) ~ JOIN ~ relationFactor ~ opt(joinConditions) ^^ { - case r1 ~ jt ~ _ ~ r2 ~ cond => + relationFactor ~ joinType.? ~ (JOIN ~> relationFactor) ~ joinConditions.? ^^ { + case r1 ~ jt ~ r2 ~ cond => Join(r1, r2, joinType = jt.getOrElse(Inner), cond) } @@ -232,151 +190,145 @@ class SqlParser extends StandardTokenParsers with PackratParsers { ON ~> expression protected lazy val joinType: Parser[JoinType] = - INNER ^^^ Inner | - LEFT ~ SEMI ^^^ LeftSemi | - LEFT ~ opt(OUTER) ^^^ LeftOuter | - RIGHT ~ opt(OUTER) ^^^ RightOuter | - FULL ~ opt(OUTER) ^^^ FullOuter - - protected lazy val filter: Parser[Expression] = WHERE ~ expression ^^ { case _ ~ e => e } - - protected lazy val orderBy: Parser[Seq[SortOrder]] = - ORDER ~> BY ~> ordering + ( INNER ^^^ Inner + | LEFT ~ SEMI ^^^ LeftSemi + | LEFT ~ OUTER.? ^^^ LeftOuter + | RIGHT ~ OUTER.? ^^^ RightOuter + | FULL ~ OUTER.? ^^^ FullOuter + ) protected lazy val ordering: Parser[Seq[SortOrder]] = - rep1sep(singleOrder, ",") | - rep1sep(expression, ",") ~ opt(direction) ^^ { - case exps ~ None => exps.map(SortOrder(_, Ascending)) - case exps ~ Some(d) => exps.map(SortOrder(_, d)) - } + ( rep1sep(singleOrder, ",") + | rep1sep(expression, ",") ~ direction.? 
^^ { + case exps ~ d => exps.map(SortOrder(_, d.getOrElse(Ascending))) + } + ) protected lazy val singleOrder: Parser[SortOrder] = - expression ~ direction ^^ { case e ~ o => SortOrder(e,o) } + expression ~ direction ^^ { case e ~ o => SortOrder(e, o) } protected lazy val direction: Parser[SortDirection] = - ASC ^^^ Ascending | - DESC ^^^ Descending - - protected lazy val grouping: Parser[Seq[Expression]] = - GROUP ~> BY ~> rep1sep(expression, ",") - - protected lazy val having: Parser[Expression] = - HAVING ~> expression - - protected lazy val limit: Parser[Expression] = - LIMIT ~> expression + ( ASC ^^^ Ascending + | DESC ^^^ Descending + ) - protected lazy val expression: Parser[Expression] = orExpression + protected lazy val expression: Parser[Expression] = + orExpression protected lazy val orExpression: Parser[Expression] = - andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1,e2) }) + andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1, e2) }) protected lazy val andExpression: Parser[Expression] = - comparisonExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1,e2) }) + comparisonExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1, e2) }) protected lazy val comparisonExpression: Parser[Expression] = - termExpression ~ "=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => EqualTo(e1, e2) } | - termExpression ~ "<" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThan(e1, e2) } | - termExpression ~ "<=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThanOrEqual(e1, e2) } | - termExpression ~ ">" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThan(e1, e2) } | - termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } | - termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | - termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | - termExpression ~ BETWEEN ~ termExpression ~ AND ~ termExpression ^^ { - case e ~ _ ~ el ~ _ ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) - } | - termExpression ~ RLIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | - termExpression ~ REGEXP ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | - termExpression ~ LIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => Like(e1, e2) } | - termExpression ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ { - case e1 ~ _ ~ _ ~ e2 => In(e1, e2) - } | - termExpression ~ NOT ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ { - case e1 ~ _ ~ _ ~ _ ~ e2 => Not(In(e1, e2)) - } | - termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } | - termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } | - NOT ~> termExpression ^^ {e => Not(e)} | - termExpression + ( termExpression ~ ("=" ~> termExpression) ^^ { case e1 ~ e2 => EqualTo(e1, e2) } + | termExpression ~ ("<" ~> termExpression) ^^ { case e1 ~ e2 => LessThan(e1, e2) } + | termExpression ~ ("<=" ~> termExpression) ^^ { case e1 ~ e2 => LessThanOrEqual(e1, e2) } + | termExpression ~ (">" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThan(e1, e2) } + | termExpression ~ (">=" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThanOrEqual(e1, e2) } + | termExpression ~ ("!=" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } + | termExpression ~ ("<>" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } + | termExpression ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ { + case e ~ el ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) + } + 
| termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } + | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } + | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) } + | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { + case e1 ~ e2 => In(e1, e2) + } + | termExpression ~ (NOT ~ IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { + case e1 ~ e2 => Not(In(e1, e2)) + } + | termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } + | termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } + | NOT ~> termExpression ^^ {e => Not(e)} + | termExpression + ) protected lazy val termExpression: Parser[Expression] = - productExpression * ( - "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1,e2) } | - "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1,e2) } ) + productExpression * + ( "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1, e2) } + | "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1, e2) } + ) protected lazy val productExpression: Parser[Expression] = - baseExpression * ( - "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1,e2) } | - "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1,e2) } | - "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1,e2) } - ) + baseExpression * + ( "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1, e2) } + | "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1, e2) } + | "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1, e2) } + ) protected lazy val function: Parser[Expression] = - SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } | - SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } | - COUNT ~> "(" ~ "*" <~ ")" ^^ { case _ => Count(Literal(1)) } | - COUNT ~> "(" ~ expression <~ ")" ^^ { case dist ~ exp => Count(exp) } | - COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } | - APPROXIMATE ~> COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { - case exp => ApproxCountDistinct(exp) - } | - APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ { - case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) - } | - FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } | - LAST ~> "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } | - AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } | - MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } | - MAX ~> "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } | - UPPER ~> "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) } | - LOWER ~> "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) } | - IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { - case c ~ "," ~ t ~ "," ~ f => If(c,t,f) - } | - (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ { - case s ~ "," ~ p => Substring(s,p,Literal(Integer.MAX_VALUE)) - } | - (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { - case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l) - } | - SQRT ~> "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } | - ABS ~> "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } | - ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ { - case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs) - } + ( SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } + | SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } + | COUNT 
~ "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) } + | COUNT ~ "(" ~> expression <~ ")" ^^ { case exp => Count(exp) } + | COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } + | APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^ + { case exp => ApproxCountDistinct(exp) } + | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ + { case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) } + | FIRST ~ "(" ~> expression <~ ")" ^^ { case exp => First(exp) } + | LAST ~ "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } + | AVG ~ "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } + | MIN ~ "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } + | MAX ~ "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } + | UPPER ~ "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) } + | LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) } + | IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ + { case c ~ t ~ f => If(c, t, f) } + | CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~ + (ELSE ~> expression).? <~ END ^^ { + case casePart ~ altPart ~ elsePart => + val altExprs = altPart.flatMap { case whenExpr ~ thenExpr => + Seq(casePart.fold(whenExpr)(EqualTo(_, whenExpr)), thenExpr) + } + CaseWhen(altExprs ++ elsePart.toList) + } + | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^ + { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } + | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ + { case s ~ p ~ l => Substring(s, p, l) } + | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } + | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } + | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ + { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) } + ) protected lazy val cast: Parser[Expression] = - CAST ~> "(" ~> expression ~ AS ~ dataType <~ ")" ^^ { case exp ~ _ ~ t => Cast(exp, t) } + CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { case exp ~ t => Cast(exp, t) } protected lazy val literal: Parser[Literal] = - numericLit ^^ { - case i if i.toLong > Int.MaxValue => Literal(i.toLong) - case i => Literal(i.toInt) - } | - NULL ^^^ Literal(null, NullType) | - floatLit ^^ {case f => Literal(f.toDouble) } | - stringLit ^^ {case s => Literal(s, StringType) } + ( numericLit ^^ { + case i if i.toLong > Int.MaxValue => Literal(i.toLong) + case i => Literal(i.toInt) + } + | NULL ^^^ Literal(null, NullType) + | floatLit ^^ {case f => Literal(f.toDouble) } + | stringLit ^^ {case s => Literal(s, StringType) } + ) protected lazy val floatLit: Parser[String] = elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars) protected lazy val baseExpression: PackratParser[Expression] = - expression ~ "[" ~ expression <~ "]" ^^ { - case base ~ _ ~ ordinal => GetItem(base, ordinal) - } | - (expression <~ ".") ~ ident ^^ { - case base ~ fieldName => GetField(base, fieldName) - } | - TRUE ^^^ Literal(true, BooleanType) | - FALSE ^^^ Literal(false, BooleanType) | - cast | - "(" ~> expression <~ ")" | - function | - "-" ~> literal ^^ UnaryMinus | - dotExpressionHeader | - ident ^^ UnresolvedAttribute | - "*" ^^^ Star(None) | - literal + ( expression ~ ("[" ~> expression <~ "]") ^^ + { case base ~ ordinal => GetItem(base, ordinal) } + | (expression <~ ".") ~ ident ^^ + { case base ~ fieldName => GetField(base, fieldName) } + | TRUE ^^^ Literal(true, BooleanType) + | 
FALSE ^^^ Literal(false, BooleanType) + | cast + | "(" ~> expression <~ ")" + | function + | "-" ~> literal ^^ UnaryMinus + | dotExpressionHeader + | ident ^^ UnresolvedAttribute + | "*" ^^^ Star(None) + | literal + ) protected lazy val dotExpressionHeader: Parser[Expression] = (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { @@ -386,55 +338,3 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected lazy val dataType: Parser[DataType] = STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType } - -class SqlLexical(val keywords: Seq[String]) extends StdLexical { - case class FloatLit(chars: String) extends Token { - override def toString = chars - } - - reserved ++= keywords.flatMap(w => allCaseVersions(w)) - - delimiters += ( - "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", - ",", ";", "%", "{", "}", ":", "[", "]", "." - ) - - override lazy val token: Parser[Token] = ( - identChar ~ rep( identChar | digit ) ^^ - { case first ~ rest => processIdent(first :: rest mkString "") } - | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ { - case i ~ None => NumericLit(i mkString "") - case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString("")) - } - | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^ - { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") } - | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^ - { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") } - | EofCh ^^^ EOF - | '\'' ~> failure("unclosed string literal") - | '\"' ~> failure("unclosed string literal") - | delim - | failure("illegal character") - ) - - override def identChar = letter | elem('_') - - override def whitespace: Parser[Any] = rep( - whitespaceChar - | '/' ~ '*' ~ comment - | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') ) - | '#' ~ rep( chrExcept(EofCh, '\n') ) - | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') ) - | '/' ~ '*' ~ failure("unclosed comment") - ) - - /** Generate all variations of upper and lower case of a given string */ - def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { - if (s == "") { - Stream(prefix) - } else { - allCaseVersions(s.tail, prefix + s.head.toLower) ++ - allCaseVersions(s.tail, prefix + s.head.toUpper) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 79e5283e86a37..64881854df7a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -348,8 +348,11 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e // Decimal and Double remain the same - case d: Divide if d.dataType == DoubleType => d - case d: Divide if d.dataType == DecimalType => d + case d: Divide if d.resolved && d.dataType == DoubleType => d + case d: Divide if d.resolved && d.dataType == DecimalType => d + + case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType)) + case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r) case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index ef1d12531f109..204904ecf04db 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -137,6 +137,9 @@ class JoinedRow extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) @@ -226,6 +229,9 @@ class JoinedRow2 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) @@ -309,6 +315,9 @@ class JoinedRow3 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) @@ -392,6 +401,9 @@ class JoinedRow4 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) @@ -475,6 +487,9 @@ class JoinedRow5 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index d68a4fabeac77..d00ec39774c35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -64,6 +64,7 @@ trait Row extends Seq[Any] with Serializable { def getShort(i: Int): Short def getByte(i: Int): Byte def getString(i: Int): String + def getAs[T](i: Int): T = apply(i).asInstanceOf[T] override def toString() = s"[${this.mkString(",")}]" @@ -118,6 +119,7 @@ object EmptyRow extends Row { def getShort(i: Int): Short = throw new UnsupportedOperationException def getByte(i: Int): Byte = throw new UnsupportedOperationException def getString(i: Int): String = throw new UnsupportedOperationException + override def getAs[T](i: Int): T = throw new UnsupportedOperationException def copy() = this } @@ -217,19 +219,19 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { /** No-arg constructor for serialization. 
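A minimal usage sketch for the getAs[T] accessor introduced in the Row changes above (illustrative only; assumes the catalyst Row and GenericRow classes touched by this patch; getAs is an unchecked cast, so a wrong type parameter only fails at run time):

    import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row}

    // Hypothetical row holding a String column and an Int column.
    val row: Row = new GenericRow(Array[Any]("alice", 42))
    val name = row.getAs[String](0)   // replaces row(0).asInstanceOf[String] at call sites
    val age  = row.getAs[Int](1)      // the boxed value is unboxed by the cast
    // JoinedRow* routes the call to the correct underlying row; SpecificMutableRow boxes its
    // specialized value before casting.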
*/ def this() = this(0) - override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value } - override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value } - override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value } - override def setFloat(ordinal: Int,value: Float): Unit = { values(ordinal) = value } - override def setInt(ordinal: Int,value: Int): Unit = { values(ordinal) = value } - override def setLong(ordinal: Int,value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int,value: String): Unit = { values(ordinal) = value } + override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value } + override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value } + override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value } + override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } + override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } + override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } + override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value } override def setNullAt(i: Int): Unit = { values(i) = null } - override def setShort(ordinal: Int,value: Short): Unit = { values(ordinal) = value } + override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } - override def update(ordinal: Int,value: Any): Unit = { values(ordinal) = value } + override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value } override def copy() = new GenericRow(values.clone()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala similarity index 97% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 9cbab3d5d0d0d..570379c533e1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -233,9 +233,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def iterator: Iterator[Any] = values.map(_.boxed).iterator - def setString(ordinal: Int, value: String) = update(ordinal, value) + override def setString(ordinal: Int, value: String) = update(ordinal, value) - def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] + override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] @@ -306,4 +306,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def getByte(i: Int): Byte = { values(i).asInstanceOf[MutableByte].value } + + override def getAs[T](i: Int): T = { + values(i).boxed.asInstanceOf[T] + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala index 1eb55715794a7..1a4ac06c7a79d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -24,9 +24,7 @@ import org.apache.spark.sql.catalyst.types.DataType /** * The data type representing [[DynamicRow]] values. */ -case object DynamicType extends DataType { - def simpleString: String = "dynamic" -} +case object DynamicType extends DataType /** * Wrap a [[Row]] as a [[DynamicRow]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 329af332d0fa1..1e22b2d03c672 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.types.BooleanType - object InterpretedPredicate { def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = apply(BindReferences.bindReference(expression, inputSchema)) @@ -95,6 +95,23 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } +/** + * Optimized version of In clause, when all filter values of In clause are + * static. + */ +case class InSet(value: Expression, hset: HashSet[Any], child: Seq[Expression]) + extends Predicate { + + def children = child + + def nullable = true // TODO: Figure out correct nullability semantics of IN. + override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}" + + override def eval(input: Row): Any = { + hset.contains(value.eval(input)) + } +} + case class And(left: Expression, right: Expression) extends BinaryPredicate { def symbol = "&&" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a4133feae8166..3693b41404fd6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter @@ -38,7 +39,8 @@ object Optimizer extends RuleExecutor[LogicalPlan] { BooleanSimplification, SimplifyFilters, SimplifyCasts, - SimplifyCaseConversionExpressions) :: + SimplifyCaseConversionExpressions, + OptimizeIn) :: Batch("Filter Pushdown", FixedPoint(100), UnionPushdown, CombineFilters, @@ -273,6 +275,20 @@ object ConstantFolding extends Rule[LogicalPlan] { } } +/** + * Replaces [[In (value, seq[Literal])]] with optimized version[[InSet (value, HashSet[Literal])]] + * which is much faster + */ +object OptimizeIn extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsDown { + case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) => + val hSet = list.map(e => e.eval(null)) + InSet(v, HashSet() ++ hSet, v +: list) + } + } +} + /** * Simplifies boolean expressions where the answer can be determined without evaluating both sides. 
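A minimal sketch of what the OptimizeIn rule above produces when every value in an IN list is a literal, using the InSet predicate added to predicates.scala (illustrative only; evaluating against a null input row works here because literals ignore their input):

    import scala.collection.immutable.HashSet
    import org.apache.spark.sql.catalyst.expressions.{In, InSet, Literal}

    val value = Literal(2)
    val list  = Seq(Literal(1), Literal(2), Literal(3))

    // Before the rule: a linear scan of the list for every input row.
    val original  = In(value, list)
    // After the rule: a single hash-set lookup per row, built the same way OptimizeIn does.
    val optimized = InSet(value, HashSet[Any]() ++ list.map(_.eval(null)), value +: list)

    assert(original.eval(null).asInstanceOf[Boolean])
    assert(optimized.eval(null).asInstanceOf[Boolean])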
* Note that this rule can eliminate expressions that might otherwise have been evaluated and thus @@ -299,6 +315,18 @@ object BooleanSimplification extends Rule[LogicalPlan] { case (_, _) => or } + case not @ Not(exp) => + exp match { + case Literal(true, BooleanType) => Literal(false) + case Literal(false, BooleanType) => Literal(true) + case GreaterThan(l, r) => LessThanOrEqual(l, r) + case GreaterThanOrEqual(l, r) => LessThan(l, r) + case LessThan(l, r) => GreaterThanOrEqual(l, r) + case LessThanOrEqual(l, r) => GreaterThan(l, r) + case Not(e) => e + case _ => not + } + // Turn "if (true) a else b" into "a", and if (false) a else b" into "b". case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 9a3848cfc6b62..b8ba2ee428a20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -39,9 +39,9 @@ case class NativeCommand(cmd: String) extends Command { } /** - * Commands of the form "SET (key) (= value)". + * Commands of the form "SET [key [= value] ]". */ -case class SetCommand(key: Option[String], value: Option[String]) extends Command { +case class SetCommand(kv: Option[(String, Option[String])]) extends Command { override def output = Seq( AttributeReference("", StringType, nullable = false)()) } @@ -81,3 +81,14 @@ case class DescribeCommand( AttributeReference("data_type", StringType, nullable = false)(), AttributeReference("comment", StringType, nullable = false)()) } + +/** + * Returned for the "! shellCommand" command + */ +case class ShellCommand(cmd: String) extends Command + + +/** + * Returned for the "SOURCE file" command + */ +case class SourceCommand(filePath: String) extends Command diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index ac043d4dd8eb9..1d375b8754182 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -19,71 +19,125 @@ package org.apache.spark.sql.catalyst.types import java.sql.Timestamp -import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral} +import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral} import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} +import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} import scala.util.parsing.combinator.RegexParsers +import org.json4s.JsonAST.JValue +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.util.Utils -/** - * Utility functions for working with DataTypes. 
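A sketch of how the reworked logical SetCommand argument above encodes the three forms of SET (illustrative only; spark.sql.shuffle.partitions is just an example key):

    import org.apache.spark.sql.catalyst.plans.logical.SetCommand

    val listAll  = SetCommand(None)                                               // SET
    val queryOne = SetCommand(Some(("spark.sql.shuffle.partitions", None)))       // SET key
    val assign   = SetCommand(Some(("spark.sql.shuffle.partitions", Some("10")))) // SET key=value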
- */ -object DataType extends RegexParsers { - protected lazy val primitiveType: Parser[DataType] = - "StringType" ^^^ StringType | - "FloatType" ^^^ FloatType | - "IntegerType" ^^^ IntegerType | - "ByteType" ^^^ ByteType | - "ShortType" ^^^ ShortType | - "DoubleType" ^^^ DoubleType | - "LongType" ^^^ LongType | - "BinaryType" ^^^ BinaryType | - "BooleanType" ^^^ BooleanType | - "DecimalType" ^^^ DecimalType | - "TimestampType" ^^^ TimestampType - - protected lazy val arrayType: Parser[DataType] = - "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { - case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) - } - protected lazy val mapType: Parser[DataType] = - "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { - case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) +object DataType { + def fromJson(json: String): DataType = parseDataType(parse(json)) + + private object JSortedObject { + def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match { + case JObject(seq) => Some(seq.toList.sortBy(_._1)) + case _ => None } + } + + // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. + private def parseDataType(json: JValue): DataType = json match { + case JString(name) => + PrimitiveType.nameToType(name) + + case JSortedObject( + ("containsNull", JBool(n)), + ("elementType", t: JValue), + ("type", JString("array"))) => + ArrayType(parseDataType(t), n) + + case JSortedObject( + ("keyType", k: JValue), + ("type", JString("map")), + ("valueContainsNull", JBool(n)), + ("valueType", v: JValue)) => + MapType(parseDataType(k), parseDataType(v), n) + + case JSortedObject( + ("fields", JArray(fields)), + ("type", JString("struct"))) => + StructType(fields.map(parseStructField)) + } - protected lazy val structField: Parser[StructField] = - ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { - case name ~ tpe ~ nullable => + private def parseStructField(json: JValue): StructField = json match { + case JSortedObject( + ("name", JString(name)), + ("nullable", JBool(nullable)), + ("type", dataType: JValue)) => + StructField(name, parseDataType(dataType), nullable) + } + + @deprecated("Use DataType.fromJson instead") + def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) + + private object CaseClassStringParser extends RegexParsers { + protected lazy val primitiveType: Parser[DataType] = + ( "StringType" ^^^ StringType + | "FloatType" ^^^ FloatType + | "IntegerType" ^^^ IntegerType + | "ByteType" ^^^ ByteType + | "ShortType" ^^^ ShortType + | "DoubleType" ^^^ DoubleType + | "LongType" ^^^ LongType + | "BinaryType" ^^^ BinaryType + | "BooleanType" ^^^ BooleanType + | "DecimalType" ^^^ DecimalType + | "TimestampType" ^^^ TimestampType + ) + + protected lazy val arrayType: Parser[DataType] = + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { + case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + } + + protected lazy val mapType: Parser[DataType] = + "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { + case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) + } + + protected lazy val structField: Parser[StructField] = + ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { + case name ~ tpe ~ nullable => StructField(name, tpe, nullable = nullable) - } + } - protected lazy val boolVal: Parser[Boolean] = - "true" ^^^ true | - "false" 
^^^ false + protected lazy val boolVal: Parser[Boolean] = + ( "true" ^^^ true + | "false" ^^^ false + ) - protected lazy val structType: Parser[DataType] = - "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { - case fields => new StructType(fields) - } + protected lazy val structType: Parser[DataType] = + "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { + case fields => new StructType(fields) + } - protected lazy val dataType: Parser[DataType] = - arrayType | - mapType | - structType | - primitiveType + protected lazy val dataType: Parser[DataType] = + ( arrayType + | mapType + | structType + | primitiveType + ) + + /** + * Parses a string representation of a DataType. + * + * TODO: Generate parser as pickler... + */ + def apply(asString: String): DataType = parseAll(dataType, asString) match { + case Success(result, _) => result + case failure: NoSuccess => + throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") + } - /** - * Parses a string representation of a DataType. - * - * TODO: Generate parser as pickler... - */ - def apply(asString: String): DataType = parseAll(dataType, asString) match { - case Success(result, _) => result - case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure") } protected[types] def buildFormattedString( @@ -111,15 +165,19 @@ abstract class DataType { def isPrimitive: Boolean = false - def simpleString: String -} + def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase + + private[sql] def jsonValue: JValue = typeName -case object NullType extends DataType { - def simpleString: String = "null" + def json: String = compact(render(jsonValue)) + + def prettyJson: String = pretty(render(jsonValue)) } +case object NullType extends DataType + object NativeType { - def all = Seq( + val all = Seq( IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) def unapply(dt: DataType): Boolean = all.contains(dt) @@ -139,6 +197,12 @@ trait PrimitiveType extends DataType { override def isPrimitive = true } +object PrimitiveType { + private[sql] val all = Seq(DecimalType, TimestampType, BinaryType) ++ NativeType.all + + private[sql] val nameToType = all.map(t => t.typeName -> t).toMap +} + abstract class NativeType extends DataType { private[sql] type JvmType @transient private[sql] val tag: TypeTag[JvmType] @@ -154,7 +218,6 @@ case object StringType extends NativeType with PrimitiveType { private[sql] type JvmType = String @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "string" } case object BinaryType extends NativeType with PrimitiveType { @@ -166,17 +229,15 @@ case object BinaryType extends NativeType with PrimitiveType { val res = x(i).compareTo(y(i)) if (res != 0) return res } - return x.length - y.length + x.length - y.length } } - def simpleString: String = "binary" } case object BooleanType extends NativeType with PrimitiveType { private[sql] type JvmType = Boolean @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "boolean" } case object TimestampType extends NativeType { @@ -187,8 +248,6 @@ case object TimestampType extends NativeType { private[sql] val ordering = new Ordering[JvmType] { def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) 
} - - def simpleString: String = "timestamp" } abstract class NumericType extends NativeType with PrimitiveType { @@ -222,7 +281,6 @@ case object LongType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Long]] private[sql] val integral = implicitly[Integral[Long]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "long" } case object IntegerType extends IntegralType { @@ -231,7 +289,6 @@ case object IntegerType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Int]] private[sql] val integral = implicitly[Integral[Int]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "integer" } case object ShortType extends IntegralType { @@ -240,7 +297,6 @@ case object ShortType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Short]] private[sql] val integral = implicitly[Integral[Short]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "short" } case object ByteType extends IntegralType { @@ -249,7 +305,6 @@ case object ByteType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Byte]] private[sql] val integral = implicitly[Integral[Byte]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "byte" } /** Matcher for any expressions that evaluate to [[FractionalType]]s */ @@ -271,7 +326,6 @@ case object DecimalType extends FractionalType { private[sql] val fractional = implicitly[Fractional[BigDecimal]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = BigDecimalAsIfIntegral - def simpleString: String = "decimal" } case object DoubleType extends FractionalType { @@ -281,7 +335,6 @@ case object DoubleType extends FractionalType { private[sql] val fractional = implicitly[Fractional[Double]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = DoubleAsIfIntegral - def simpleString: String = "double" } case object FloatType extends FractionalType { @@ -291,12 +344,12 @@ case object FloatType extends FractionalType { private[sql] val fractional = implicitly[Fractional[Float]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = FloatAsIfIntegral - def simpleString: String = "float" } object ArrayType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. 
*/ def apply(elementType: DataType): ArrayType = ArrayType(elementType, true) + def typeName: String = "array" } /** @@ -309,11 +362,14 @@ object ArrayType { case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append( - s"${prefix}-- element: ${elementType.simpleString} (containsNull = ${containsNull})\n") + s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") DataType.buildFormattedString(elementType, s"$prefix |", builder) } - def simpleString: String = "array" + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("elementType" -> elementType.jsonValue) ~ + ("containsNull" -> containsNull) } /** @@ -325,14 +381,22 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT case class StructField(name: String, dataType: DataType, nullable: Boolean) { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n") + builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") DataType.buildFormattedString(dataType, s"$prefix |", builder) } + + private[sql] def jsonValue: JValue = { + ("name" -> name) ~ + ("type" -> dataType.jsonValue) ~ + ("nullable" -> nullable) + } } object StructType { protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) + + def typeName = "struct" } case class StructType(fields: Seq[StructField]) extends DataType { @@ -348,8 +412,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { * have a name matching the given name, `null` will be returned. 
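A sketch of the JSON round trip enabled by the new jsonValue/json/fromJson methods (illustrative only; the field names and types are arbitrary, and typeName is derived from the class name, e.g. StructType -> "struct", DoubleType -> "double"):

    import org.apache.spark.sql.catalyst.types._

    val schema = StructType(Seq(
      StructField("name", StringType, nullable = true),
      StructField("scores", ArrayType(DoubleType, containsNull = false), nullable = false)))

    val json   = schema.json              // compact JSON; prettyJson gives an indented form
    val parsed = DataType.fromJson(json)  // replaces the deprecated case-class-string parser
    assert(parsed == schema)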
*/ def apply(name: String): StructField = { - nameToField.get(name).getOrElse( - throw new IllegalArgumentException(s"Field ${name} does not exist.")) + nameToField.getOrElse(name, throw new IllegalArgumentException(s"Field $name does not exist.")) } /** @@ -358,7 +421,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { */ def apply(names: Set[String]): StructType = { val nonExistFields = names -- fieldNamesSet - if (!nonExistFields.isEmpty) { + if (nonExistFields.nonEmpty) { throw new IllegalArgumentException( s"Field ${nonExistFields.mkString(",")} does not exist.") } @@ -384,7 +447,9 @@ case class StructType(fields: Seq[StructField]) extends DataType { fields.foreach(field => field.buildFormattedString(prefix, builder)) } - def simpleString: String = "struct" + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("fields" -> fields.map(_.jsonValue)) } object MapType { @@ -394,6 +459,8 @@ object MapType { */ def apply(keyType: DataType, valueType: DataType): MapType = MapType(keyType: DataType, valueType: DataType, true) + + def simpleName = "map" } /** @@ -407,12 +474,16 @@ case class MapType( valueType: DataType, valueContainsNull: Boolean) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") - builder.append(s"${prefix}-- value: ${valueType.simpleString} " + - s"(valueContainsNull = ${valueContainsNull})\n") + builder.append(s"$prefix-- key: ${keyType.typeName}\n") + builder.append(s"$prefix-- value: ${valueType.typeName} " + + s"(valueContainsNull = $valueContainsNull)\n") DataType.buildFormattedString(keyType, s"$prefix |", builder) DataType.buildFormattedString(valueType, s"$prefix |", builder) } - def simpleString: String = "map" + override private[sql] def jsonValue: JValue = + ("type" -> typeName) ~ + ("keyType" -> keyType.jsonValue) ~ + ("valueType" -> valueType.jsonValue) ~ + ("valueContainsNull" -> valueContainsNull) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 5809a108ff62e..7b45738c4fc95 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.{BeforeAndAfter, FunSuite} -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.types.IntegerType +import org.apache.spark.sql.catalyst.types._ class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseSensitiveCatalog = new SimpleCatalog(true) @@ -33,6 +34,12 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType)(), + 
AttributeReference("e", ShortType)()) before { caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation) @@ -74,7 +81,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { val e = intercept[RuntimeException] { caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None)) } - assert(e.getMessage === "Table Not Found: tAbLe") + assert(e.getMessage == "Table Not Found: tAbLe") assert( caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) === @@ -106,4 +113,31 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { } assert(e.getMessage().toLowerCase.contains("unresolved plan")) } + + test("divide should be casted into fractional types") { + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType)(), + AttributeReference("e", ShortType)()) + + val expr0 = 'a / 2 + val expr1 = 'a / 'b + val expr2 = 'a / 'c + val expr3 = 'a / 'd + val expr4 = 'e / 'e + val plan = caseInsensitiveAnalyze(Project( + Alias(expr0, s"Analyzer($expr0)")() :: + Alias(expr1, s"Analyzer($expr1)")() :: + Alias(expr2, s"Analyzer($expr2)")() :: + Alias(expr3, s"Analyzer($expr3)")() :: + Alias(expr4, s"Analyzer($expr4)")() :: Nil, testRelation2)) + val pl = plan.asInstanceOf[Project].projectList + assert(pl(0).dataType == DoubleType) + assert(pl(1).dataType == DoubleType) + assert(pl(2).dataType == DoubleType) + assert(pl(3).dataType == DecimalType) + assert(pl(4).dataType == DoubleType) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 63931af4bac3d..692ed78a7292c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp +import scala.collection.immutable.HashSet + import org.scalatest.FunSuite import org.scalatest.Matchers._ import org.scalautils.TripleEqualsSupport.Spread import org.apache.spark.sql.catalyst.types._ + /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -145,6 +148,24 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true) } + test("INSET") { + val hS = HashSet[Any]() + 1 + 2 + val nS = HashSet[Any]() + 1 + 2 + null + val one = Literal(1) + val two = Literal(2) + val three = Literal(3) + val nl = Literal(null) + val s = Seq(one, two) + val nullS = Seq(one, two, null) + checkEvaluation(InSet(one, hS, one +: s), true) + checkEvaluation(InSet(two, hS, two +: s), true) + checkEvaluation(InSet(two, nS, two +: nullS), true) + checkEvaluation(InSet(nl, nS, nl +: nullS), true) + checkEvaluation(InSet(three, hS, three +: s), false) + checkEvaluation(InSet(three, nS, three +: nullS), false) + checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true) + } + test("MaxOf") { checkEvaluation(MaxOf(1, 2), 2) checkEvaluation(MaxOf(2, 1), 2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala new file mode 100644 index 
0000000000000..97a78ec971c39 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.immutable.HashSet +import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.types._ + +// For implicit conversions +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class OptimizeInSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("AnalysisNodes", Once, + EliminateAnalysisOperators) :: + Batch("ConstantFolding", Once, + ConstantFolding, + BooleanSimplification, + OptimizeIn) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("OptimizedIn test: In clause optimized to InSet") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2)))) + .analyze + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2, + UnresolvedAttribute("a") +: Seq(Literal(1),Literal(2)))) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("OptimizedIn test: In clause not optimized in case filter has attributes") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) + .analyze + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) + .analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index 3bf7382ac67a6..5ab2b5316ab10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -22,7 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel -import org.apache.spark.storage.StorageLevel.MEMORY_ONLY +import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK /** Holds a cached logical 
plan and its data */ private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) @@ -74,10 +74,14 @@ private[sql] trait CacheManager { cachedData.clear() } - /** Caches the data produced by the logical representation of the given schema rdd. */ + /** + * Caches the data produced by the logical representation of the given schema rdd. Unlike + * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing + * the in-memory columnar representation of the underlying table is expensive. + */ private[sql] def cacheQuery( query: SchemaRDD, - storageLevel: StorageLevel = MEMORY_ONLY): Unit = writeLock { + storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.optimizedPlan if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f6f4cf3b80d41..07e6e2eccddf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -35,6 +35,7 @@ private[spark] object SQLConf { val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec" + val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" // This is only used for the thriftserver val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" @@ -131,6 +132,9 @@ private[sql] trait SQLConf { private[spark] def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean + private[spark] def columnNameOfCorruptRecord: String = + getConf(COLUMN_NAME_OF_CORRUPT_RECORD, "_corrupt_record") + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
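A sketch of the user-visible effect of the new cacheQuery default above (illustrative only; assumes an existing SQLContext named sqlContext with a registered table, here called people):

    // Cached tables now persist at MEMORY_AND_DISK rather than MEMORY_ONLY, so an evicted
    // partition is re-read from disk instead of being rebuilt from the columnar batches.
    sqlContext.cacheTable("people")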
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 7a55c5bf97a71..23e7b2d270777 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration +import org.apache.spark.SparkContext import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection @@ -31,12 +32,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.types.DataType import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.SparkStrategies +import org.apache.spark.sql.execution.{SparkStrategies, _} import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.{Logging, SparkContext} /** * :: AlphaComponent :: @@ -66,12 +66,17 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, caseSensitive = true) + @transient protected[sql] val optimizer = Optimizer + @transient - protected[sql] val parser = new catalyst.SqlParser + protected[sql] val sqlParser = { + val fallback = new catalyst.SqlParser + new catalyst.SparkSQLParser(fallback(_)) + } - protected[sql] def parseSql(sql: String): LogicalPlan = parser(sql) + protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser(sql) protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } @@ -195,9 +200,12 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = { + val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord val appliedSchema = - Option(schema).getOrElse(JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0))) - val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema) + Option(schema).getOrElse( + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord))) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) applySchema(rowRDD, appliedSchema) } @@ -206,8 +214,11 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { - val appliedSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio)) - val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema) + val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord + val appliedSchema = + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) applySchema(rowRDD, appliedSchema) } @@ -409,8 +420,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * It is only used by PySpark. 
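A sketch of how the new corrupt-record column threads through jsonRDD (illustrative only; assumes an existing SQLContext sqlContext and an RDD[String] jsonLines; _malformed is an example override of the _corrupt_record default):

    sqlContext.setConf("spark.sql.columnNameOfCorruptRecord", "_malformed")

    // Both schema inference and row parsing receive the configured column name, so lines that
    // fail to parse are expected to surface under "_malformed" rather than failing the job.
    val people = sqlContext.jsonRDD(jsonLines, 1.0)
    people.registerTempTable("people")
    sqlContext.sql("SELECT _malformed FROM people WHERE _malformed IS NOT NULL")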
*/ private[sql] def parseDataType(dataTypeString: String): DataType = { - val parser = org.apache.spark.sql.catalyst.types.DataType - parser(dataTypeString) + DataType.fromJson(dataTypeString) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 594bf8ffc20e1..948122d42f0e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -360,7 +360,7 @@ class SchemaRDD( join: Boolean = false, outer: Boolean = false, alias: Option[String] = None) = - new SchemaRDD(sqlContext, Generate(generator, join, outer, None, logicalPlan)) + new SchemaRDD(sqlContext, Generate(generator, join, outer, alias, logicalPlan)) /** * Returns this RDD as a SchemaRDD. Intended primarily to force the invocation of the implicit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index c006c4330ff66..f8171c3be3207 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -148,8 +148,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { * It goes through the entire dataset once to determine the schema. */ def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = { - val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0)) - val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) + val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord + val appliedScalaSchema = + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(json.rdd, 1.0, columnNameOfCorruptJsonRecord)) + val scalaRowRDD = + JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema, columnNameOfCorruptJsonRecord) val logicalPlan = LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) @@ -162,10 +166,14 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { */ @Experimental def jsonRDD(json: JavaRDD[String], schema: StructType): JavaSchemaRDD = { + val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord val appliedScalaSchema = Option(asScalaDataType(schema)).getOrElse( - JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[SStructType] - val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema( + json.rdd, 1.0, columnNameOfCorruptJsonRecord))).asInstanceOf[SStructType] + val scalaRowRDD = JsonRDD.jsonStringToRow( + json.rdd, appliedScalaSchema, columnNameOfCorruptJsonRecord) val logicalPlan = LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 4f79173a26f88..22ab0e2613f21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -38,7 +38,7 @@ private[sql] object InMemoryRelation { new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child)() } -private[sql] case class CachedBatch(buffers: 
Array[ByteBuffer], stats: Row) +private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: Row) private[sql] case class InMemoryRelation( output: Seq[Attribute], @@ -91,7 +91,7 @@ private[sql] case class InMemoryRelation( val stats = Row.fromSeq( columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _)) - CachedBatch(columnBuilders.map(_.build()), stats) + CachedBatch(columnBuilders.map(_.build().array()), stats) } def hasNext = rowIterator.hasNext @@ -238,8 +238,9 @@ private[sql] case class InMemoryColumnarTableScan( def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = { val rows = cacheBatches.flatMap { cachedBatch => // Build column accessors - val columnAccessors = - requestedColumnIndices.map(cachedBatch.buffers(_)).map(ColumnAccessor(_)) + val columnAccessors = requestedColumnIndices.map { batch => + ColumnAccessor(ByteBuffer.wrap(cachedBatch.buffers(batch))) + } // Extract rows via column accessors new Iterator[Row] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index c386fd121c5de..38877c28de3a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -39,7 +39,8 @@ case class Generate( child: SparkPlan) extends UnaryNode { - protected def generatorOutput: Seq[Attribute] = { + // This must be a val since the generator output expr ids are not preserved by serialization. + protected val generatorOutput: Seq[Attribute] = { if (join && outer) { generator.output.map(_.withNullability(true)) } else { @@ -62,7 +63,7 @@ case class Generate( newProjection(child.output ++ nullValues, child.output) val joinProjection = - newProjection(child.output ++ generator.output, child.output ++ generator.output) + newProjection(child.output ++ generatorOutput, child.output ++ generatorOutput) val joinedRow = new JoinedRow iter.flatMap {row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5c16d0c624128..4f1af7234d551 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.parquet._ + private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => @@ -34,13 +35,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // Find left semi joins where at least some predicates can be evaluated by matching join keys case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => - val semiJoin = execution.LeftSemiJoinHash( + val semiJoin = joins.LeftSemiJoinHash( leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => - execution.LeftSemiJoinBNL( - planLater(left), planLater(right), condition) :: Nil + joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil case _ => Nil } } @@ -50,13 +50,13 @@ 
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * evaluated by matching hash keys. * * This strategy applies a simple optimization based on the estimates of the physical sizes of - * the two join sides. When planning a [[execution.BroadcastHashJoin]], if one side has an + * the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an * estimated physical size smaller than the user-settable threshold * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be * ''broadcasted'' to all of the executors involved in the join, as a * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they - * will instead be used to decide the build side in a [[execution.ShuffledHashJoin]]. + * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]]. */ object HashJoin extends Strategy with PredicateHelper { @@ -66,8 +66,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left: LogicalPlan, right: LogicalPlan, condition: Option[Expression], - side: BuildSide) = { - val broadcastHashJoin = execution.BroadcastHashJoin( + side: joins.BuildSide) = { + val broadcastHashJoin = execution.joins.BroadcastHashJoin( leftKeys, rightKeys, side, planLater(left), planLater(right)) condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } @@ -76,27 +76,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.autoBroadcastJoinThreshold > 0 && right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.autoBroadcastJoinThreshold > 0 && left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - BuildRight + joins.BuildRight } else { - BuildLeft + joins.BuildLeft } - val hashJoin = - execution.ShuffledHashJoin( - leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) + val hashJoin = joins.ShuffledHashJoin( + leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => - execution.HashOuterJoin( + joins.HashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil case _ => Nil @@ -164,8 +163,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, joinType, condition) => val buildSide = - if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight else BuildLeft - execution.BroadcastNestedLoopJoin( + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { + joins.BuildRight + 
} else { + joins.BuildLeft + } + joins.BroadcastNestedLoopJoin( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case _ => Nil } @@ -174,10 +177,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, _, None) => - execution.CartesianProduct(planLater(left), planLater(right)) :: Nil + execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => execution.Filter(condition, - execution.CartesianProduct(planLater(left), planLater(right))) :: Nil + execution.joins.CartesianProduct(planLater(left), planLater(right))) :: Nil case _ => Nil } } @@ -274,9 +277,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil case logical.LocalRelation(output, data) => + val nPartitions = if (data.isEmpty) 1 else numPartitions PhysicalRDD( output, - RDDConversions.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil + RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => @@ -300,8 +304,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case class CommandStrategy(context: SQLContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.SetCommand(key, value) => - Seq(execution.SetCommand(key, value, plan.output)(context)) + case logical.SetCommand(kv) => + Seq(execution.SetCommand(kv, plan.output)(context)) case logical.ExplainCommand(logicalPlan, extended) => Seq(execution.ExplainCommand(logicalPlan, plan.output, extended)(context)) case logical.CacheTableCommand(tableName, optPlan, isLazy) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index d49633c24ad4d..5859eba408ee1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -48,29 +48,28 @@ trait Command { * :: DeveloperApi :: */ @DeveloperApi -case class SetCommand( - key: Option[String], value: Option[String], output: Seq[Attribute])( +case class SetCommand(kv: Option[(String, Option[String])], output: Seq[Attribute])( @transient context: SQLContext) extends LeafNode with Command with Logging { - override protected lazy val sideEffectResult: Seq[Row] = (key, value) match { - // Set value for key k. - case (Some(k), Some(v)) => - if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { + override protected lazy val sideEffectResult: Seq[Row] = kv match { + // Set value for the key. 
+ case Some((key, Some(value))) => + if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") - context.setConf(SQLConf.SHUFFLE_PARTITIONS, v) - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$v")) + context.setConf(SQLConf.SHUFFLE_PARTITIONS, value) + Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value")) } else { - context.setConf(k, v) - Seq(Row(s"$k=$v")) + context.setConf(key, value) + Seq(Row(s"$key=$value")) } - // Query the value bound to key k. - case (Some(k), _) => + // Query the value bound to the key. + case Some((key, None)) => // TODO (lian) This is just a workaround to make the Simba ODBC driver work. // Should remove this once we get the ODBC driver updated. - if (k == "-v") { + if (key == "-v") { val hiveJars = Seq( "hive-exec-0.12.0.jar", "hive-service-0.12.0.jar", @@ -84,23 +83,20 @@ case class SetCommand( Row("system:java.class.path=" + hiveJars), Row("system:sun.java.command=shark.SharkServer2")) } else { - if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { + if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}")) } else { - Seq(Row(s"$k=${context.getConf(k, "")}")) + Seq(Row(s"$key=${context.getConf(key, "")}")) } } // Query all key-value pairs that are set in the SQLConf of the context. - case (None, None) => + case _ => context.getAllConfs.map { case (k, v) => Row(s"$k=$v") }.toSeq - - case _ => - throw new IllegalArgumentException() } override def otherCopyArgs = context :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala deleted file mode 100644 index 2890a563bed48..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ /dev/null @@ -1,624 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution - -import java.util.{HashMap => JavaHashMap} - -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent._ -import scala.concurrent.duration._ - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.util.collection.CompactBuffer - -@DeveloperApi -sealed abstract class BuildSide - -@DeveloperApi -case object BuildLeft extends BuildSide - -@DeveloperApi -case object BuildRight extends BuildSide - -trait HashJoin { - self: SparkPlan => - - val leftKeys: Seq[Expression] - val rightKeys: Seq[Expression] - val buildSide: BuildSide - val left: SparkPlan - val right: SparkPlan - - lazy val (buildPlan, streamedPlan) = buildSide match { - case BuildLeft => (left, right) - case BuildRight => (right, left) - } - - lazy val (buildKeys, streamedKeys) = buildSide match { - case BuildLeft => (leftKeys, rightKeys) - case BuildRight => (rightKeys, leftKeys) - } - - def output = left.output ++ right.output - - @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) - @transient lazy val streamSideKeyGenerator = - newMutableProjection(streamedKeys, streamedPlan.output) - - def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { - // TODO: Use Spark's HashMap implementation. - - val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() - var currentRow: Row = null - - // Create a mapping of buildKeys -> rows - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val existingMatchList = hashTable.get(rowKey) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[Row]() - hashTable.put(rowKey, newMatchList) - newMatchList - } else { - existingMatchList - } - matchList += currentRow.copy() - } - } - - new Iterator[Row] { - private[this] var currentStreamedRow: Row = _ - private[this] var currentHashMatches: CompactBuffer[Row] = _ - private[this] var currentMatchPosition: Int = -1 - - // Mutable per row objects. - private[this] val joinRow = new JoinedRow2 - - private[this] val joinKeys = streamSideKeyGenerator() - - override final def hasNext: Boolean = - (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || - (streamIter.hasNext && fetchNext()) - - override final def next() = { - val ret = buildSide match { - case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) - case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) - } - currentMatchPosition += 1 - ret - } - - /** - * Searches the streamed iterator for the next row that has at least one match in hashtable. - * - * @return true if the search is successful, and false if the streamed iterator runs out of - * tuples. 
- */ - private final def fetchNext(): Boolean = { - currentHashMatches = null - currentMatchPosition = -1 - - while (currentHashMatches == null && streamIter.hasNext) { - currentStreamedRow = streamIter.next() - if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashTable.get(joinKeys.currentValue) - } - } - - if (currentHashMatches == null) { - false - } else { - currentMatchPosition = 0 - true - } - } - } - } -} - -/** - * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using - * the join keys. This operator requires loading the associated partition in both side into memory. - */ -@DeveloperApi -case class HashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - - override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def output = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case x => - throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - } - - @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) - @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] - - // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala - // iterator for performance purpose. - - private[this] def leftOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val rightNullRow = new GenericRow(right.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - leftIter.iterator.flatMap { l => - joinedRow.withLeft(l) - var matched = false - (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => - matched = true - joinedRow.copy - } else { - Nil - }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the - // records in right side. 
- // If we didn't get any proper row, then append a single row with empty right - joinedRow.withRight(rightNullRow).copy - }) - } - } - - private[this] def rightOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val leftNullRow = new GenericRow(left.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - rightIter.iterator.flatMap { r => - joinedRow.withRight(r) - var matched = false - (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => - matched = true - joinedRow.copy - } else { - Nil - }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the - // records in left side. - // If we didn't get any proper row, then append a single row with empty left. - joinedRow.withLeft(leftNullRow).copy - }) - } - } - - private[this] def fullOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val leftNullRow = new GenericRow(left.output.length) - val rightNullRow = new GenericRow(right.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - if (!key.anyNull) { - // Store the positions of records in right, if one of its associated row satisfy - // the join condition. - val rightMatchedSet = scala.collection.mutable.Set[Int]() - leftIter.iterator.flatMap[Row] { l => - joinedRow.withLeft(l) - var matched = false - rightIter.zipWithIndex.collect { - // 1. For those matched (satisfy the join condition) records with both sides filled, - // append them directly - - case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { - matched = true - // if the row satisfy the join condition, add its index into the matched set - rightMatchedSet.add(idx) - joinedRow.copy - } - } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // 2. For those unmatched records in left, append additional records with empty right. - - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all - // of the records in right side. - // If we didn't get any proper row, then append a single row with empty right. - joinedRow.withRight(rightNullRow).copy - }) - } ++ rightIter.zipWithIndex.collect { - // 3. For those unmatched records in right, append additional records with empty left. - - // Re-visiting the records in right, and append additional row with empty left, if its not - // in the matched set. 
- case (r, idx) if (!rightMatchedSet.contains(idx)) => { - joinedRow(leftNullRow, r).copy - } - } - } else { - leftIter.iterator.map[Row] { l => - joinedRow(l, rightNullRow).copy - } ++ rightIter.iterator.map[Row] { r => - joinedRow(leftNullRow, r).copy - } - } - } - - private[this] def buildHashTable( - iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { - val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() - while (iter.hasNext) { - val currentRow = iter.next() - val rowKey = keyGenerator(currentRow) - - var existingMatchList = hashTable.get(rowKey) - if (existingMatchList == null) { - existingMatchList = new CompactBuffer[Row]() - hashTable.put(rowKey, existingMatchList) - } - - existingMatchList += currentRow.copy() - } - - hashTable - } - - def execute() = { - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // TODO this probably can be replaced by external sort (sort merged join?) - // Build HashMap for current partition in left relation - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - // Build HashMap for current partition in right relation - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - - import scala.collection.JavaConversions._ - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - joinType match { - case LeftOuter => leftHashTable.keysIterator.flatMap { key => - leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case RightOuter => rightHashTable.keysIterator.flatMap { key => - rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - } - } -} - -/** - * :: DeveloperApi :: - * Performs an inner hash join of two child relations by first shuffling the data using the join - * keys. - */ -@DeveloperApi -case class ShuffledHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override def outputPartitioning: Partitioning = left.outputPartitioning - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - def execute() = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { - (buildIter, streamIter) => joinIterators(buildIter, streamIter) - } - } -} - -/** - * :: DeveloperApi :: - * Build the right table's join keys into a HashSet, and iteratively go through the left - * table, to find the if join keys are in the Hash set. 
- */ -@DeveloperApi -case class LeftSemiJoinHash( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - val buildSide = BuildRight - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def output = left.output - - def execute() = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashSet = new java.util.HashSet[Row]() - var currentRow: Row = null - - // Create a Hash set of buildKeys - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - hashSet.add(rowKey) - } - } - } - - val joinKeys = streamSideKeyGenerator() - streamIter.filter(current => { - !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) - }) - } - } -} - - -/** - * :: DeveloperApi :: - * Performs an inner hash join of two child relations. When the output RDD of this operator is - * being constructed, a Spark job is asynchronously started to calculate the values for the - * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed - * relation is not shuffled. - */ -@DeveloperApi -case class BroadcastHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - - override def requiredChildDistribution = - UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - - @transient - val broadcastFuture = future { - sparkContext.broadcast(buildPlan.executeCollect()) - } - - def execute() = { - val broadcastRelation = Await.result(broadcastFuture, 5.minute) - - streamedPlan.execute().mapPartitions { streamedIter => - joinIterators(broadcastRelation.value.iterator, streamedIter) - } - } -} - -/** - * :: DeveloperApi :: - * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys - * for hash join. - */ -@DeveloperApi -case class LeftSemiJoinBNL( - streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - extends BinaryNode { - // TODO: Override requiredChildDistribution. 
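For orientation, a minimal, hypothetical sketch of how user code reaches the broadcast path of the BroadcastHashJoin operator documented above (the SQLContext value and the table names people/orders are illustrative and not part of this patch; only the spark.sql.autoBroadcastJoinThreshold key comes from SQLConf):

import org.apache.spark.sql.SQLContext

// Illustrative sketch: the threshold steers the planner between the broadcast and
// shuffled hash join operators being moved into the joins package in this patch.
object JoinPlanExample {
  def showJoinPlan(sqlContext: SQLContext): Unit = {
    // Relations whose estimated physical size is at or below ~10 MB may be broadcast;
    // a value of -1 disables automatic broadcast joins entirely.
    sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", (10 * 1024 * 1024).toString)
    val joined = sqlContext.sql(
      "SELECT * FROM people JOIN orders ON people.id = orders.personId")
    // Expect BroadcastHashJoin in the plan when one side's size estimate is under the
    // threshold, and ShuffledHashJoin otherwise.
    println(joined.queryExecution.executedPlan)
  }
}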
- - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - def output = left.output - - /** The Streamed Relation */ - def left = streamed - /** The Broadcast relation */ - def right = broadcast - - @transient lazy val boundCondition = - InterpretedPredicate( - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true))) - - def execute() = { - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) - - streamed.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow - - streamedIter.filter(streamedRow => { - var i = 0 - var matched = false - - while (i < broadcastedRelation.value.size && !matched) { - val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { - matched = true - } - i += 1 - } - matched - }) - } - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { - def output = left.output ++ right.output - - def execute() = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) - - leftResults.cartesian(rightResults).mapPartitions { iter => - val joinedRow = new JoinedRow - iter.map(r => joinedRow(r._1, r._2)) - } - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class BroadcastNestedLoopJoin( - left: SparkPlan, - right: SparkPlan, - buildSide: BuildSide, - joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { - // TODO: Override requiredChildDistribution. - - /** BuildRight means the right relation <=> the broadcast relation. */ - val (streamed, broadcast) = buildSide match { - case BuildRight => (left, right) - case BuildLeft => (right, left) - } - - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - override def output = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case _ => - left.output ++ right.output - } - } - - @transient lazy val boundCondition = - InterpretedPredicate( - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true))) - - def execute() = { - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) - - /** All rows that either match both-way, or rows from streamed joined with nulls. */ - val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new CompactBuffer[Row] - // TODO: Use Spark's BitSet. - val includedBroadcastTuples = - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - val joinedRow = new JoinedRow - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - - streamedIter.foreach { streamedRow => - var i = 0 - var streamRowMatched = false - - while (i < broadcastedRelation.value.size) { - // TODO: One bitset per partition instead of per row. 
- val broadcastedRow = broadcastedRelation.value(i) - buildSide match { - case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += joinedRow(streamedRow, broadcastedRow).copy() - streamRowMatched = true - includedBroadcastTuples += i - case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += joinedRow(broadcastedRow, streamedRow).copy() - streamRowMatched = true - includedBroadcastTuples += i - case _ => - } - i += 1 - } - - (streamRowMatched, joinType, buildSide) match { - case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += joinedRow(streamedRow, rightNulls).copy() - case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += joinedRow(leftNulls, streamedRow).copy() - case _ => - } - } - Iterator((matchedRows, includedBroadcastTuples)) - } - - val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) - val allIncludedBroadcastTuples = - if (includedBroadcastTuples.count == 0) { - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - } else { - includedBroadcastTuples.reduce(_ ++ _) - } - - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - /** Rows from broadcasted joined with nulls. */ - val broadcastRowsWithNulls: Seq[Row] = { - val buf: CompactBuffer[Row] = new CompactBuffer() - var i = 0 - val rel = broadcastedRelation.value - while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { - (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) - case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) - case _ => - } - } - i += 1 - } - buf.toSeq - } - - // TODO: Breaks lineage. - sparkContext.union( - matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala new file mode 100644 index 0000000000000..d88ab6367a1b3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import scala.concurrent._ +import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext.Implicits.global + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Performs an inner hash join of two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * relation is not shuffled. + */ +@DeveloperApi +case class BroadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan) + extends BinaryNode with HashJoin { + + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + + override def requiredChildDistribution = + UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + + @transient + private val broadcastFuture = future { + sparkContext.broadcast(buildPlan.executeCollect()) + } + + override def execute() = { + val broadcastRelation = Await.result(broadcastFuture, 5.minute) + + streamedPlan.execute().mapPartitions { streamedIter => + joinIterators(broadcastRelation.value.iterator, streamedIter) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala new file mode 100644 index 0000000000000..36aad13778bd2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class BroadcastNestedLoopJoin( + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) extends BinaryNode { + // TODO: Override requiredChildDistribution. + + /** BuildRight means the right relation <=> the broadcast relation. 
*/ + private val (streamed, broadcast) = buildSide match { + case BuildRight => (left, right) + case BuildLeft => (right, left) + } + + override def outputPartitioning: Partitioning = streamed.outputPartitioning + + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case _ => + left.output ++ right.output + } + } + + @transient private lazy val boundCondition = + InterpretedPredicate( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) + + override def execute() = { + val broadcastedRelation = + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + + /** All rows that either match both-way, or rows from streamed joined with nulls. */ + val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => + val matchedRows = new CompactBuffer[Row] + // TODO: Use Spark's BitSet. + val includedBroadcastTuples = + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + val joinedRow = new JoinedRow + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + + streamedIter.foreach { streamedRow => + var i = 0 + var streamRowMatched = false + + while (i < broadcastedRelation.value.size) { + // TODO: One bitset per partition instead of per row. + val broadcastedRow = broadcastedRelation.value(i) + buildSide match { + case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => + matchedRows += joinedRow(streamedRow, broadcastedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => + matchedRows += joinedRow(broadcastedRow, streamedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case _ => + } + i += 1 + } + + (streamRowMatched, joinType, buildSide) match { + case (false, LeftOuter | FullOuter, BuildRight) => + matchedRows += joinedRow(streamedRow, rightNulls).copy() + case (false, RightOuter | FullOuter, BuildLeft) => + matchedRows += joinedRow(leftNulls, streamedRow).copy() + case _ => + } + } + Iterator((matchedRows, includedBroadcastTuples)) + } + + val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) + val allIncludedBroadcastTuples = + if (includedBroadcastTuples.count == 0) { + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + } else { + includedBroadcastTuples.reduce(_ ++ _) + } + + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + /** Rows from broadcasted joined with nulls. */ + val broadcastRowsWithNulls: Seq[Row] = { + val buf: CompactBuffer[Row] = new CompactBuffer() + var i = 0 + val rel = broadcastedRelation.value + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) + case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) + case _ => + } + } + i += 1 + } + buf.toSeq + } + + // TODO: Breaks lineage. 
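    // Note on the TODO above: the null-padded broadcast-side rows are built on the driver
    // from the already-collected broadcast relation and re-distributed via makeRDD, so that
    // half of the union below cannot be recomputed from the children's RDD lineage.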
+ sparkContext.union( + matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala new file mode 100644 index 0000000000000..76c14c02aab34 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { + override def output = left.output ++ right.output + + override def execute() = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.cartesian(rightResults).mapPartitions { iter => + val joinedRow = new JoinedRow + iter.map(r => joinedRow(r._1, r._2)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala new file mode 100644 index 0000000000000..472b2e6ca6b4a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow2, Row} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.util.collection.CompactBuffer + + +trait HashJoin { + self: SparkPlan => + + val leftKeys: Seq[Expression] + val rightKeys: Seq[Expression] + val buildSide: BuildSide + val left: SparkPlan + val right: SparkPlan + + protected lazy val (buildPlan, streamedPlan) = buildSide match { + case BuildLeft => (left, right) + case BuildRight => (right, left) + } + + protected lazy val (buildKeys, streamedKeys) = buildSide match { + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) + } + + override def output = left.output ++ right.output + + @transient protected lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) + @transient protected lazy val streamSideKeyGenerator = + newMutableProjection(streamedKeys, streamedPlan.output) + + protected def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = + { + // TODO: Use Spark's HashMap implementation. + + val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() + var currentRow: Row = null + + // Create a mapping of buildKeys -> rows + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if (!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, newMatchList) + newMatchList + } else { + existingMatchList + } + matchList += currentRow.copy() + } + } + + new Iterator[Row] { + private[this] var currentStreamedRow: Row = _ + private[this] var currentHashMatches: CompactBuffer[Row] = _ + private[this] var currentMatchPosition: Int = -1 + + // Mutable per row objects. + private[this] val joinRow = new JoinedRow2 + + private[this] val joinKeys = streamSideKeyGenerator() + + override final def hasNext: Boolean = + (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || + (streamIter.hasNext && fetchNext()) + + override final def next() = { + val ret = buildSide match { + case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) + } + currentMatchPosition += 1 + ret + } + + /** + * Searches the streamed iterator for the next row that has at least one match in hashtable. + * + * @return true if the search is successful, and false if the streamed iterator runs out of + * tuples. 
+ */ + private final def fetchNext(): Boolean = { + currentHashMatches = null + currentMatchPosition = -1 + + while (currentHashMatches == null && streamIter.hasNext) { + currentStreamedRow = streamIter.next() + if (!joinKeys(currentStreamedRow).anyNull) { + currentHashMatches = hashTable.get(joinKeys.currentValue) + } + } + + if (currentHashMatches == null) { + false + } else { + currentMatchPosition = 0 + true + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala new file mode 100644 index 0000000000000..b73041d306b36 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.{HashMap => JavaHashMap} + +import scala.collection.JavaConversions._ + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. 
+ */ +@DeveloperApi +case class HashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def outputPartitioning: Partitioning = joinType match { + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } + + @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) + @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] + + // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala + // iterator for performance purpose. + + private[this] def leftOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + leftIter.iterator.flatMap { l => + joinedRow.withLeft(l) + var matched = false + (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in right side. + // If we didn't get any proper row, then append a single row with empty right + joinedRow.withRight(rightNullRow).copy + }) + } + } + + private[this] def rightOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + rightIter.iterator.flatMap { r => + joinedRow.withRight(r) + var matched = false + (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in left side. + // If we didn't get any proper row, then append a single row with empty left. 
+ joinedRow.withLeft(leftNullRow).copy + }) + } + } + + private[this] def fullOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + if (!key.anyNull) { + // Store the positions of records in right, if one of its associated row satisfy + // the join condition. + val rightMatchedSet = scala.collection.mutable.Set[Int]() + leftIter.iterator.flatMap[Row] { l => + joinedRow.withLeft(l) + var matched = false + rightIter.zipWithIndex.collect { + // 1. For those matched (satisfy the join condition) records with both sides filled, + // append them directly + + case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { + matched = true + // if the row satisfy the join condition, add its index into the matched set + rightMatchedSet.add(idx) + joinedRow.copy + } + } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // 2. For those unmatched records in left, append additional records with empty right. + + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all + // of the records in right side. + // If we didn't get any proper row, then append a single row with empty right. + joinedRow.withRight(rightNullRow).copy + }) + } ++ rightIter.zipWithIndex.collect { + // 3. For those unmatched records in right, append additional records with empty left. + + // Re-visiting the records in right, and append additional row with empty left, if its not + // in the matched set. + case (r, idx) if (!rightMatchedSet.contains(idx)) => { + joinedRow(leftNullRow, r).copy + } + } + } else { + leftIter.iterator.map[Row] { l => + joinedRow(l, rightNullRow).copy + } ++ rightIter.iterator.map[Row] { r => + joinedRow(leftNullRow, r).copy + } + } + } + + private[this] def buildHashTable( + iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { + val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() + while (iter.hasNext) { + val currentRow = iter.next() + val rowKey = keyGenerator(currentRow) + + var existingMatchList = hashTable.get(rowKey) + if (existingMatchList == null) { + existingMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, existingMatchList) + } + + existingMatchList += currentRow.copy() + } + + hashTable + } + + override def execute() = { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // TODO this probably can be replaced by external sort (sort merged join?) 
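      // Approach: materialize both inputs of the current partition into in-memory hash maps
      // keyed by the join key, then let the matching outer iterator defined above stitch the
      // rows together, null-padding whichever side has no match. Memory use grows with the
      // partition size, which is what the sort-merge TODO above refers to.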
+ // Build HashMap for current partition in left relation + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + // Build HashMap for current partition in right relation + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + joinType match { + case LeftOuter => leftHashTable.keysIterator.flatMap { key => + leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case RightOuter => rightHashTable.keysIterator.flatMap { key => + rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala new file mode 100644 index 0000000000000..60003d1900d85 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys + * for hash join. + */ +@DeveloperApi +case class LeftSemiJoinBNL( + streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) + extends BinaryNode { + // TODO: Override requiredChildDistribution. 
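As the Scaladoc above notes, this operator handles LEFT SEMI JOINs that carry no equi-join keys, so the hash-based LeftSemiJoinHash cannot apply. A hypothetical sketch of a query shape that would fall through to this operator, assuming a dialect that accepts LEFT SEMI JOIN (the SQLContext value and the table/column names are illustrative, not part of this patch):

import org.apache.spark.sql.SQLContext

// Illustrative sketch: a purely non-equi semi-join condition leaves nothing to hash on,
// so the right side is broadcast and scanned once per streamed row.
object NonEquiSemiJoinExample {
  def run(sqlContext: SQLContext) =
    sqlContext.sql(
      "SELECT * FROM events LEFT SEMI JOIN ranges " +
        "ON events.ts >= ranges.lo AND events.ts < ranges.hi")
}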
+ + override def outputPartitioning: Partitioning = streamed.outputPartitioning + + override def output = left.output + + /** The Streamed Relation */ + override def left = streamed + /** The Broadcast relation */ + override def right = broadcast + + @transient private lazy val boundCondition = + InterpretedPredicate( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) + + override def execute() = { + val broadcastedRelation = + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + + streamed.execute().mapPartitions { streamedIter => + val joinedRow = new JoinedRow + + streamedIter.filter(streamedRow => { + var i = 0 + var matched = false + + while (i < broadcastedRelation.value.size && !matched) { + val broadcastedRow = broadcastedRelation.value(i) + if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { + matched = true + } + i += 1 + } + matched + }) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala new file mode 100644 index 0000000000000..ea7babf3be948 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.{Expression, Row} +import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Build the right table's join keys into a HashSet, and iteratively go through the left + * table, to find the if join keys are in the Hash set. 
+ */ +@DeveloperApi +case class LeftSemiJoinHash( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashJoin { + + override val buildSide = BuildRight + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def output = left.output + + override def execute() = { + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + val hashSet = new java.util.HashSet[Row]() + var currentRow: Row = null + + // Create a Hash set of buildKeys + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if (!rowKey.anyNull) { + val keyExists = hashSet.contains(rowKey) + if (!keyExists) { + hashSet.add(rowKey) + } + } + } + + val joinKeys = streamSideKeyGenerator() + streamIter.filter(current => { + !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) + }) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala new file mode 100644 index 0000000000000..8247304c1dc2c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Performs an inner hash join of two child relations by first shuffling the data using the join + * keys. 
+ */ +@DeveloperApi +case class ShuffledHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan) + extends BinaryNode with HashJoin { + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def execute() = { + buildPlan.execute().zipPartitions(streamedPlan.execute()) { + (buildIter, streamIter) => joinIterators(buildIter, streamIter) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala new file mode 100644 index 0000000000000..7f2ab1765b28f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Physical execution operators for join operations. 
+ */ +package object joins { + + @DeveloperApi + sealed abstract class BuildSide + + @DeveloperApi + case object BuildRight extends BuildSide + + @DeveloperApi + case object BuildLeft extends BuildSide + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 0f27fd13e7379..61ee960aad9d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.json import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal +import java.sql.Timestamp +import com.fasterxml.jackson.core.JsonProcessingException import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.rdd.RDD @@ -34,16 +36,19 @@ private[sql] object JsonRDD extends Logging { private[sql] def jsonStringToRow( json: RDD[String], - schema: StructType): RDD[Row] = { - parseJson(json).map(parsed => asRow(parsed, schema)) + schema: StructType, + columnNameOfCorruptRecords: String): RDD[Row] = { + parseJson(json, columnNameOfCorruptRecords).map(parsed => asRow(parsed, schema)) } private[sql] def inferSchema( json: RDD[String], - samplingRatio: Double = 1.0): StructType = { + samplingRatio: Double = 1.0, + columnNameOfCorruptRecords: String): StructType = { require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) - val allKeys = parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _) + val allKeys = + parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) createSchema(allKeys) } @@ -273,7 +278,9 @@ private[sql] object JsonRDD extends Logging { case atom => atom } - private def parseJson(json: RDD[String]): RDD[Map[String, Any]] = { + private def parseJson( + json: RDD[String], + columnNameOfCorruptRecords: String): RDD[Map[String, Any]] = { // According to [Jackson-72: https://jira.codehaus.org/browse/JACKSON-72], // ObjectMapper will not return BigDecimal when // "DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS" is disabled @@ -288,12 +295,16 @@ private[sql] object JsonRDD extends Logging { // For example: for {"key": 1, "key":2}, we will get "key"->2. 
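      // Records that Jackson fails to parse are not dropped: the catch block below emits the
      // raw string as a one-column map keyed by columnNameOfCorruptRecords, so malformed
      // input surfaces as a row in that column instead of failing the whole job.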
val mapper = new ObjectMapper() iter.flatMap { record => - val parsed = mapper.readValue(record, classOf[Object]) match { - case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil - case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] - } + try { + val parsed = mapper.readValue(record, classOf[Object]) match { + case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil + case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] + } - parsed + parsed + } catch { + case e: JsonProcessingException => Map(columnNameOfCorruptRecords -> record) :: Nil + } } }) } @@ -361,6 +372,14 @@ private[sql] object JsonRDD extends Logging { } } + private def toTimestamp(value: Any): Timestamp = { + value match { + case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) + case value: java.lang.Long => new Timestamp(value) + case value: java.lang.String => Timestamp.valueOf(value) + } + } + private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={ if (value == null) { null @@ -377,6 +396,7 @@ private[sql] object JsonRDD extends Logging { case ArrayType(elementType, _) => value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) + case TimestampType => toTimestamp(value) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 2941b9793597f..e6389cf77a4c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.parquet import java.io.IOException +import scala.util.Try + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job @@ -323,14 +325,14 @@ private[parquet] object ParquetTypesConverter extends Logging { } def convertFromString(string: String): Seq[Attribute] = { - DataType(string) match { + Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match { case s: StructType => s.toAttributes case other => sys.error(s"Can convert $string to row") } } def convertToString(schema: Seq[Attribute]): String = { - StructType.fromAttributes(schema).toString + StructType.fromAttributes(schema).json } def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1e624f97004f5..444bc95009c31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.storage.RDDBlockId +import org.apache.spark.storage.{StorageLevel, RDDBlockId} case class BigData(s: String) @@ -55,10 +55,10 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 10000 - sparkContext.parallelize(1 to 1000000, 1).map(_ => BigData(data)).registerTempTable("bigData") - cacheTable("bigData") - 
assert(table("bigData").count() === 1000000L) - uncacheTable("bigData") + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData") + table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(table("bigData").count() === 200000L) + table("bigData").unpersist() } test("calling .cache() should use in-memory columnar caching") { @@ -69,7 +69,7 @@ class CachedTableSuite extends QueryTest { test("calling .unpersist() should drop in-memory columnar cache") { table("testData").cache() table("testData").count() - table("testData").unpersist(true) + table("testData").unpersist(blocking = true) assertCached(table("testData"), 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala index 8fb59c5830f6d..100ecb45e9e88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.types.DataType + class DataTypeSuite extends FunSuite { test("construct an ArrayType") { @@ -55,4 +57,30 @@ class DataTypeSuite extends FunSuite { struct(Set("b", "d", "e", "f")) } } + + def checkDataTypeJsonRepr(dataType: DataType): Unit = { + test(s"JSON - $dataType") { + assert(DataType.fromJson(dataType.json) === dataType) + } + } + + checkDataTypeJsonRepr(BooleanType) + checkDataTypeJsonRepr(ByteType) + checkDataTypeJsonRepr(ShortType) + checkDataTypeJsonRepr(IntegerType) + checkDataTypeJsonRepr(LongType) + checkDataTypeJsonRepr(FloatType) + checkDataTypeJsonRepr(DoubleType) + checkDataTypeJsonRepr(DecimalType) + checkDataTypeJsonRepr(TimestampType) + checkDataTypeJsonRepr(StringType) + checkDataTypeJsonRepr(BinaryType) + checkDataTypeJsonRepr(ArrayType(DoubleType, true)) + checkDataTypeJsonRepr(ArrayType(StringType, false)) + checkDataTypeJsonRepr(MapType(IntegerType, StringType, true)) + checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false)) + checkDataTypeJsonRepr( + StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", ArrayType(DoubleType), nullable = false)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index d001abb7e1fcc..45e58afe9d9a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -147,6 +147,14 @@ class DslQuerySuite extends QueryTest { (1, 1, 1, 2) :: Nil) } + test("SPARK-3858 generator qualifiers are discarded") { + checkAnswer( + arrayData.as('ad) + .generate(Explode("data" :: Nil, 'data), alias = Some("ex")) + .select("ex.data".attr), + Seq(1, 2, 3, 2, 3, 4).map(Seq(_))) + } + test("average") { checkAnswer( testData2.groupBy()(avg('a)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 6c7697ece8c56..07f4d2946c1b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi} import org.apache.spark.sql.execution._ +import 
org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6fb6cb8db0c8f..a94022c0cf6e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{ShuffledHashJoin, BroadcastHashJoin} +import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.test._ import org.scalatest.BeforeAndAfterAll import java.util.TimeZone @@ -42,7 +42,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { TimeZone.setDefault(origZone) } - test("SPARK-3176 Added Parser of SQL ABS()") { checkAnswer( sql("SELECT ABS(-1.3)"), @@ -61,7 +60,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { 4) } - test("SPARK-2041 column name equals tablename") { checkAnswer( sql("SELECT tableName FROM tableName"), @@ -680,9 +678,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), ("true", "false") :: Nil) } - + test("SPARK-3371 Renaming a function expression with group by gives error") { registerFunction("len", (s: String) => s.length) checkAnswer( - sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)} + sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1) + } + + test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") { + checkAnswer( + sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), 1) + } + + test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") { + checkAnswer( + sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 69e0adbd3ee0d..f53acc8c9f718 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -67,10 +67,11 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3) checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2) checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4) + checkBatchPruning("NOT (i < 88)", 88 to 100, 1, 2) // With unsupported predicate checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2) - checkBatchPruning("NOT (i < 88)", 88 to 100, 5, 10) + checkBatchPruning(s"NOT (i in (${(1 to 30).mkString(",")}))", 31 to 100, 5, 10) def checkBatchPruning( filter: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index bfbf431a11913..f14ffca0e4d35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -19,10 
+19,11 @@ package org.apache.spark.sql.execution import org.scalatest.FunSuite +import org.apache.spark.sql.{SQLConf, execution} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.{SQLConf, execution} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 685e788207725..7bb08f1b513ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -21,8 +21,12 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ +import java.sql.Timestamp + class JsonSuite extends QueryTest { import TestJsonData._ TestJsonData @@ -50,6 +54,12 @@ class JsonSuite extends QueryTest { val doubleNumber: Double = 1.7976931348623157E308d checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType)) + + checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType)) + checkTypePromotion(new Timestamp(intNumber.toLong), + enforceCorrectType(intNumber.toLong, TimestampType)) + val strDate = "2014-09-30 12:34:56" + checkTypePromotion(Timestamp.valueOf(strDate), enforceCorrectType(strDate, TimestampType)) } test("Get compatible type") { @@ -636,7 +646,65 @@ class JsonSuite extends QueryTest { ("str_a_1", null, null) :: ("str_a_2", null, null) :: (null, "str_b_3", null) :: - ("str_a_4", "str_b_4", "str_c_4") ::Nil + ("str_a_4", "str_b_4", "str_c_4") :: Nil + ) + } + + test("Corrupt records") { + // Test if we can query corrupt records. + val oldColumnNameOfCorruptRecord = TestSQLContext.columnNameOfCorruptRecord + TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + + val jsonSchemaRDD = jsonRDD(corruptRecords) + jsonSchemaRDD.registerTempTable("jsonTable") + + val schema = StructType( + StructField("_unparsed", StringType, true) :: + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) + + assert(schema === jsonSchemaRDD.schema) + + // In HiveContext, backticks should be used to access columns starting with a underscore. 
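// (Illustration, hypothetical HiveContext usage:
//    hiveContext.sql("SELECT `_unparsed` FROM jsonTable WHERE `_unparsed` IS NOT NULL")
//  the unquoted column name in the queries below works because this suite runs against
//  TestSQLContext.)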
+ checkAnswer( + sql( + """ + |SELECT a, b, c, _unparsed + |FROM jsonTable + """.stripMargin), + (null, null, null, "{") :: + (null, null, null, "") :: + (null, null, null, """{"a":1, b:2}""") :: + (null, null, null, """{"a":{, b:3}""") :: + ("str_a_4", "str_b_4", "str_c_4", null) :: + (null, null, null, "]") :: Nil ) + + checkAnswer( + sql( + """ + |SELECT a, b, c + |FROM jsonTable + |WHERE _unparsed IS NULL + """.stripMargin), + ("str_a_4", "str_b_4", "str_c_4") :: Nil + ) + + checkAnswer( + sql( + """ + |SELECT _unparsed + |FROM jsonTable + |WHERE _unparsed IS NOT NULL + """.stripMargin), + Seq("{") :: + Seq("") :: + Seq("""{"a":1, b:2}""") :: + Seq("""{"a":{, b:3}""") :: + Seq("]") :: Nil + ) + + TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index fc833b8b54e4c..eaca9f0508a12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -143,4 +143,13 @@ object TestJsonData { """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) + + val corruptRecords = + TestSQLContext.sparkContext.parallelize( + """{""" :: + """""" :: + """{"a":1, b:2}""" :: + """{"a":{, b:3}""" :: + """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: + """]""" :: Nil) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 07adf731405af..25e41ecf28e2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -789,7 +789,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result3(0)(1) === "the answer") Utils.deleteRecursively(tmpdir) } - + test("Querying on empty parquet throws exception (SPARK-3536)") { val tmpdir = Utils.createTempDir() Utils.deleteRecursively(tmpdir) @@ -798,4 +798,18 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result1.size === 0) Utils.deleteRecursively(tmpdir) } + + test("DataType string parser compatibility") { + val schema = StructType(List( + StructField("c1", IntegerType, false), + StructField("c2", BinaryType, false))) + + val fromCaseClassString = ParquetTypesConverter.convertFromString(schema.toString) + val fromJson = ParquetTypesConverter.convertFromString(schema.json) + + (fromCaseClassString, fromJson).zipped.foreach { (a, b) => + assert(a.name == b.name) + assert(a.dataType === b.dataType) + } + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 910174a153768..accf61576b804 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -172,7 +172,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) 
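// Note: the pattern change below tracks SetCommand's reworked shape in this refactor,
// roughly  case class SetCommand(kv: Option[(String, Option[String])]),
// so "SET key=value" now arrives as Some((key, Some(value))) rather than two separate Options.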
result.queryExecution.logical match { - case SetCommand(Some(key), Some(value)) if (key == SQLConf.THRIFTSERVER_POOL) => + case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value)))) => sessionToActivePool(parentSession) = value logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") case _ => diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 3475c2c9db080..d68dd090b5e6c 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -62,9 +62,11 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { def captureOutput(source: String)(line: String) { buffer += s"$source> $line" - if (line.contains(expectedAnswers(next.get()))) { - if (next.incrementAndGet() == expectedAnswers.size) { - foundAllExpectedAnswers.trySuccess(()) + if (next.get() < expectedAnswers.size) { + if (line.startsWith(expectedAnswers(next.get()))) { + if (next.incrementAndGet() == expectedAnswers.size) { + foundAllExpectedAnswers.trySuccess(()) + } } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index c5844e92eaaa9..430ffb29989ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -18,118 +18,50 @@ package org.apache.spark.sql.hive import scala.language.implicitConversions -import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.PackratParsers + import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.SqlLexical +import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, SqlLexical} /** - * A parser that recognizes all HiveQL constructs together with several Spark SQL specific - * extensions like CACHE TABLE and UNCACHE TABLE. + * A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions. */ -private[hive] class ExtendedHiveQlParser extends StandardTokenParsers with PackratParsers { - - def apply(input: String): LogicalPlan = { - // Special-case out set commands since the value fields can be - // complex to handle without RegexParsers. Also this approach - // is clearer for the several possible cases of set commands. 
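// Note: the SET, shell ("!"), CACHE/UNCACHE and SOURCE special-casing removed here no longer
// lives in this parser; HiveQl.scala (later in this patch) wraps ExtendedHiveQlParser in the
// shared SparkSQLParser, and this class keeps only the ADD JAR / ADD FILE / DFS forms plus
// the HiveQL fallback.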
- if (input.trim.toLowerCase.startsWith("set")) { - input.trim.drop(3).split("=", 2).map(_.trim) match { - case Array("") => // "set" - SetCommand(None, None) - case Array(key) => // "set key" - SetCommand(Some(key), None) - case Array(key, value) => // "set key=value" - SetCommand(Some(key), Some(value)) - } - } else if (input.trim.startsWith("!")) { - ShellCommand(input.drop(1)) - } else { - phrase(query)(new lexical.Scanner(input)) match { - case Success(r, x) => r - case x => sys.error(x.toString) - } - } - } - - protected case class Keyword(str: String) - - protected val ADD = Keyword("ADD") - protected val AS = Keyword("AS") - protected val CACHE = Keyword("CACHE") - protected val DFS = Keyword("DFS") - protected val FILE = Keyword("FILE") - protected val JAR = Keyword("JAR") - protected val LAZY = Keyword("LAZY") - protected val SET = Keyword("SET") - protected val SOURCE = Keyword("SOURCE") - protected val TABLE = Keyword("TABLE") - protected val UNCACHE = Keyword("UNCACHE") - +private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { protected implicit def asParser(k: Keyword): Parser[String] = lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - protected def allCaseConverse(k: String): Parser[String] = - lexical.allCaseVersions(k).map(x => x : Parser[String]).reduce(_ | _) + protected val ADD = Keyword("ADD") + protected val DFS = Keyword("DFS") + protected val FILE = Keyword("FILE") + protected val JAR = Keyword("JAR") - protected val reservedWords = - this.getClass + private val reservedWords = + this + .getClass .getMethods .filter(_.getReturnType == classOf[Keyword]) .map(_.invoke(this).asInstanceOf[Keyword].str) override val lexical = new SqlLexical(reservedWords) - protected lazy val query: Parser[LogicalPlan] = - cache | uncache | addJar | addFile | dfs | source | hiveQl + protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl protected lazy val hiveQl: Parser[LogicalPlan] = restInput ^^ { - case statement => HiveQl.createPlan(statement.trim()) - } - - // Returns the whole input string - protected lazy val wholeInput: Parser[String] = new Parser[String] { - def apply(in: Input) = - Success(in.source.toString, in.drop(in.source.length())) - } - - // Returns the rest of the input string that are not parsed yet - protected lazy val restInput: Parser[String] = new Parser[String] { - def apply(in: Input) = - Success( - in.source.subSequence(in.offset, in.source.length).toString, - in.drop(in.source.length())) - } - - protected lazy val cache: Parser[LogicalPlan] = - CACHE ~> opt(LAZY) ~ (TABLE ~> ident) ~ opt(AS ~> hiveQl) ^^ { - case isLazy ~ tableName ~ plan => - CacheTableCommand(tableName, plan, isLazy.isDefined) - } - - protected lazy val uncache: Parser[LogicalPlan] = - UNCACHE ~ TABLE ~> ident ^^ { - case tableName => UncacheTableCommand(tableName) + case statement => HiveQl.createPlan(statement.trim) } - protected lazy val addJar: Parser[LogicalPlan] = - ADD ~ JAR ~> restInput ^^ { - case jar => AddJar(jar.trim()) + protected lazy val dfs: Parser[LogicalPlan] = + DFS ~> wholeInput ^^ { + case command => NativeCommand(command.trim) } - protected lazy val addFile: Parser[LogicalPlan] = + private lazy val addFile: Parser[LogicalPlan] = ADD ~ FILE ~> restInput ^^ { - case file => AddFile(file.trim()) + case input => AddFile(input.trim) } - protected lazy val dfs: Parser[LogicalPlan] = - DFS ~> wholeInput ^^ { - case command => NativeCommand(command.trim()) - } - - protected lazy val source: Parser[LogicalPlan] 
= - SOURCE ~> restInput ^^ { - case file => SourceCommand(file.trim()) + private lazy val addJar: Parser[LogicalPlan] = + ADD ~ JAR ~> restInput ^^ { + case input => AddJar(input.trim) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index cc0605b0adb35..addd5bed8426d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,31 +19,28 @@ package org.apache.spark.sql.hive import scala.util.parsing.combinator.RegexParsers -import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo} -import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} +import org.apache.hadoop.hive.metastore.api.{FieldSchema, SerDeInfo, StorageDescriptor, Partition => TPartition, Table => TTable} import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, Catalog} +import org.apache.spark.sql.catalyst.analysis.Catalog import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.util.Utils /* Implicit conversions */ import scala.collection.JavaConversions._ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging { - import HiveMetastoreTypes._ + import org.apache.spark.sql.hive.HiveMetastoreTypes._ /** Connection to hive metastore. Usages should lock on `this`. */ protected[hive] val client = Hive.get(hive.hiveconf) @@ -137,10 +134,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) = { val childOutputDataTypes = child.output.map(_.dataType) - // Only check attributes, not partitionKeys since they are always strings. - // TODO: Fully support inserting into partitioned tables. 
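// Note: take(child.output.length) below lets one check cover both static and dynamic
// partition INSERTs; with static partition values the child query supplies no partition
// columns, so only as many table columns as the child actually produces are type-checked
// (exercised by the SPARK-3810 tests added in HiveQuerySuite later in this patch).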
val tableOutputDataTypes = - table.attributes.map(_.dataType) ++ table.partitionKeys.map(_.dataType) + (table.attributes ++ table.partitionKeys).take(child.output.length).map(_.dataType) if (childOutputDataTypes == tableOutputDataTypes) { p diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 32c9175f181bb..7cc14dc7a9c9e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils +import org.apache.spark.sql.catalyst.SparkSQLParser import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -38,10 +39,6 @@ import scala.collection.JavaConversions._ */ private[hive] case object NativePlaceholder extends Command -private[hive] case class ShellCommand(cmd: String) extends Command - -private[hive] case class SourceCommand(filePath: String) extends Command - private[hive] case class AddFile(filePath: String) extends Command private[hive] case class AddJar(path: String) extends Command @@ -126,9 +123,11 @@ private[hive] object HiveQl { "TOK_CREATETABLE", "TOK_DESCTABLE" ) ++ nativeCommands - - // It parses hive sql query along with with several Spark SQL specific extensions - protected val hiveSqlParser = new ExtendedHiveQlParser + + protected val hqlParser = { + val fallback = new ExtendedHiveQlParser + new SparkSQLParser(fallback(_)) + } /** * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations @@ -218,7 +217,7 @@ private[hive] object HiveQl { def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql)) /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql: String): LogicalPlan = hiveSqlParser(sql) + def parseSql(sql: String): LogicalPlan = hqlParser(sql) /** Creates LogicalPlan for a given HiveQL string. 
*/ def createPlan(sql: String) = { @@ -639,7 +638,7 @@ private[hive] object HiveQl { def nodeToRelation(node: Node): LogicalPlan = node match { case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) => - Subquery(alias, nodeToPlan(query)) + Subquery(cleanIdentifier(alias), nodeToPlan(query)) case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) => val Token("TOK_SELECT", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 508d8239c7628..5c66322f1ed99 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -167,10 +167,10 @@ private[hive] trait HiveStrategies { database.get, tableName, query, - InsertIntoHiveTable(_: MetastoreRelation, - Map(), - query, - true)(hiveContext)) :: Nil + InsertIntoHiveTable(_: MetastoreRelation, + Map(), + query, + overwrite = true)(hiveContext)) :: Nil case _ => Nil } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index f8b4e898ec41d..f0785d8882636 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -69,33 +69,36 @@ case class InsertIntoHiveTable( * Wraps with Hive types based on object inspector. * TODO: Consolidate all hive OI/data interface code. */ - protected def wrap(a: (Any, ObjectInspector)): Any = a match { - case (s: String, oi: JavaHiveVarcharObjectInspector) => - new HiveVarchar(s, s.size) - - case (bd: BigDecimal, oi: JavaHiveDecimalObjectInspector) => - new HiveDecimal(bd.underlying()) - - case (row: Row, oi: StandardStructObjectInspector) => - val struct = oi.create() - row.zip(oi.getAllStructFieldRefs: Seq[StructField]).foreach { - case (data, field) => - oi.setStructFieldData(struct, field, wrap(data, field.getFieldObjectInspector)) + protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { + case _: JavaHiveVarcharObjectInspector => + (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) + + case _: JavaHiveDecimalObjectInspector => + (o: Any) => new HiveDecimal(o.asInstanceOf[BigDecimal].underlying()) + + case soi: StandardStructObjectInspector => + val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) + (o: Any) => { + val struct = soi.create() + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach { + (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + } + struct } - struct - case (s: Seq[_], oi: ListObjectInspector) => - val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector)) - seqAsJavaList(wrappedSeq) + case loi: ListObjectInspector => + val wrapper = wrapperFor(loi.getListElementObjectInspector) + (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) - case (m: Map[_, _], oi: MapObjectInspector) => - val keyOi = oi.getMapKeyObjectInspector - val valueOi = oi.getMapValueObjectInspector - val wrappedMap = m.map { case (key, value) => wrap(key, keyOi) -> wrap(value, valueOi) } - mapAsJavaMap(wrappedMap) + case moi: MapObjectInspector => + val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector) + val valueWrapper = wrapperFor(moi.getMapValueObjectInspector) + (o: Any) => 
mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) => + keyWrapper(key) -> valueWrapper(value) + }) - case (obj, _) => - obj + case _ => + identity[Any] } def saveAsHiveFile( @@ -103,7 +106,7 @@ case class InsertIntoHiveTable( valueClass: Class[_], fileSinkConf: FileSinkDesc, conf: SerializableWritable[JobConf], - writerContainer: SparkHiveWriterContainer) { + writerContainer: SparkHiveWriterContainer): Unit = { assert(valueClass != null, "Output value class not set") conf.value.setOutputValueClass(valueClass) @@ -122,7 +125,7 @@ case class InsertIntoHiveTable( writerContainer.commitJob() // Note that this function is executed on executor side - def writeToFile(context: TaskContext, iterator: Iterator[Row]) { + def writeToFile(context: TaskContext, iterator: Iterator[Row]): Unit = { val serializer = newSerializer(fileSinkConf.getTableInfo) val standardOI = ObjectInspectorUtils .getStandardObjectInspector( @@ -131,6 +134,7 @@ case class InsertIntoHiveTable( .asInstanceOf[StructObjectInspector] val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val wrappers = fieldOIs.map(wrapperFor) val outputData = new Array[Any](fieldOIs.length) // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it @@ -141,13 +145,13 @@ case class InsertIntoHiveTable( iterator.foreach { row => var i = 0 while (i < fieldOIs.length) { - // TODO (lian) avoid per row dynamic dispatching and pattern matching cost in `wrap` - outputData(i) = wrap(row(i), fieldOIs(i)) + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row(i)) i += 1 } - val writer = writerContainer.getLocalFileWriter(row) - writer.write(serializer.serialize(outputData, standardOI)) + writerContainer + .getLocalFileWriter(row) + .write(serializer.serialize(outputData, standardOI)) } writerContainer.close() @@ -207,7 +211,7 @@ case class InsertIntoHiveTable( // Report error if any static partition appears after a dynamic partition val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty) - isDynamic.init.zip(isDynamic.tail).find(_ == (true, false)).foreach { _ => + if (isDynamic.init.zip(isDynamic.tail).contains((true, false))) { throw new SparkException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) } } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java new file mode 100644 index 0000000000000..6c4f378bc5471 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +public class UDFIntegerToString extends UDF { + public String evaluate(Integer i) { + return i.toString(); + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java new file mode 100644 index 0000000000000..d2d39a8c4dc28 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.List; + +public class UDFListListInt extends UDF { + /** + * + * @param obj + * SQL schema: array> + * Java Type: List> + * @return + */ + public long evaluate(Object obj) { + if (obj == null) { + return 0l; + } + List listList = (List) obj; + long retVal = 0; + for (List aList : listList) { + @SuppressWarnings("unchecked") + List list = (List) aList; + @SuppressWarnings("unchecked") + Integer someInt = (Integer) list.get(1); + try { + retVal += (long) (someInt.intValue()); + } catch (NullPointerException e) { + System.out.println(e); + } + } + return retVal; + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java new file mode 100644 index 0000000000000..efd34df293c88 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.List; +import org.apache.commons.lang.StringUtils; + +public class UDFListString extends UDF { + + public String evaluate(Object a) { + if (a == null) { + return null; + } + @SuppressWarnings("unchecked") + List s = (List) a; + + return StringUtils.join(s, ','); + } + + +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java new file mode 100644 index 0000000000000..a369188d471e8 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +public class UDFStringString extends UDF { + public String evaluate(String s1, String s2) { + return s1 + " " + s2; + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java new file mode 100644 index 0000000000000..0165591a7ce78 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +public class UDFTwoListList extends UDF { + public String evaluate(Object o1, Object o2) { + UDFListListInt udf = new UDFListListInt(); + + return String.format("%s, %s", udf.evaluate(o1), udf.evaluate(o2)); + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index a35c40efdc207..14e791fe0f0ee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.plans.logical.NativeCommand -import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2e282a9ade40c..2829105f43716 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -22,6 +22,7 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -675,6 +676,41 @@ class HiveQuerySuite extends HiveComparisonTest { sql("SELECT * FROM boom").queryExecution.analyzed } + test("SPARK-3810: PreInsertionCasts static partitioning support") { + val analyzedPlan = { + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("INSERT INTO TABLE withparts PARTITION(ds='1', hr='2') SELECT key, value FROM src") + .queryExecution.analyzed + } + + assertResult(1, "Duplicated project detected\n" + analyzedPlan) { + analyzedPlan.collect { + case _: Project => () + }.size + } + } + + test("SPARK-3810: PreInsertionCasts dynamic partitioning support") { + val analyzedPlan = { + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + + sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src") + .queryExecution.analyzed + } + + assertResult(1, "Duplicated project detected\n" + analyzedPlan) { + analyzedPlan.collect { + case _: Project => () + }.size + } + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. 
val testKey = "spark.sql.key.usedfortestonly" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index e4324e9528f9b..872f28d514efe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -17,33 +17,37 @@ package org.apache.spark.sql.hive.execution -import java.io.{DataOutput, DataInput} +import java.io.{DataInput, DataOutput} import java.util import java.util.Properties -import org.apache.spark.util.Utils - -import scala.collection.JavaConversions._ - import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.serde2.{SerDeStats, AbstractSerDe} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, ObjectInspector} - -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.ql.udf.generic.GenericUDF import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject - -import org.apache.spark.sql.Row +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} +import org.apache.hadoop.io.Writable +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ + +import org.apache.spark.util.Utils + +import scala.collection.JavaConversions._ case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) +// Case classes for the custom UDF's. +case class IntegerCaseClass(i: Int) +case class ListListIntCaseClass(lli: Seq[(Int, Int, Int)]) +case class StringCaseClass(s: String) +case class ListStringCaseClass(l: Seq[String]) + /** * A test suite for Hive custom UDFs. 
*/ -class HiveUdfSuite extends HiveComparisonTest { +class HiveUdfSuite extends QueryTest { + import TestHive._ test("spark sql udf test that returns a struct") { registerFunction("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) @@ -81,7 +85,84 @@ class HiveUdfSuite extends HiveComparisonTest { } test("SPARK-2693 udaf aggregates test") { - assert(sql("SELECT percentile(key,1) FROM src").first === sql("SELECT max(key) FROM src").first) + checkAnswer(sql("SELECT percentile(key,1) FROM src LIMIT 1"), + sql("SELECT max(key) FROM src").collect().toSeq) + } + + test("UDFIntegerToString") { + val testData = TestHive.sparkContext.parallelize( + IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil) + testData.registerTempTable("integerTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'") + checkAnswer( + sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(), + Seq(Seq("1"), Seq("2"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") + + TestHive.reset() + } + + test("UDFListListInt") { + val testData = TestHive.sparkContext.parallelize( + ListListIntCaseClass(Nil) :: + ListListIntCaseClass(Seq((1, 2, 3))) :: + ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil) + testData.registerTempTable("listListIntTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") + checkAnswer( + sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(), + Seq(Seq(0), Seq(2), Seq(13))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") + + TestHive.reset() + } + + test("UDFListString") { + val testData = TestHive.sparkContext.parallelize( + ListStringCaseClass(Seq("a", "b", "c")) :: + ListStringCaseClass(Seq("d", "e")) :: Nil) + testData.registerTempTable("listStringTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") + checkAnswer( + sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(), + Seq(Seq("a,b,c"), Seq("d,e"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") + + TestHive.reset() + } + + test("UDFStringString") { + val testData = TestHive.sparkContext.parallelize( + StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil) + testData.registerTempTable("stringTable") + + sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") + checkAnswer( + sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(), + Seq(Seq("hello world"), Seq("hello goodbye"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") + + TestHive.reset() + } + + test("UDFTwoListList") { + val testData = TestHive.sparkContext.parallelize( + ListListIntCaseClass(Nil) :: + ListListIntCaseClass(Seq((1, 2, 3))) :: + ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: + Nil) + testData.registerTempTable("TwoListTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") + checkAnswer( + sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(), + Seq(Seq("0, 0"), Seq("2, 2"), Seq("13, 13"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") + + TestHive.reset() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 3647bb1c4ce7d..fbe6ac765c009 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala 
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -68,5 +68,11 @@ class SQLQuerySuite extends QueryTest { checkAnswer( sql("SELECT k FROM (SELECT `key` AS `k` FROM src) a"), sql("SELECT `key` FROM src").collect().toSeq) - } + } + + test("SPARK-3834 Backticks not correctly handled in subquery aliases") { + checkAnswer( + sql("SELECT a.key FROM (SELECT key FROM src) `a`"), + sql("SELECT `key` FROM src").collect().toSeq) + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 8511390cb1ad5..e5592e52b0d2d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -231,8 +231,7 @@ class CheckpointSuite extends TestSuiteBase { // failure, are re-processed or not. test("recovery with file input stream") { // Set up the streaming context and input streams - val testDir = Files.createTempDir() - testDir.deleteOnExit() + val testDir = Utils.createTempDir() var ssc = new StreamingContext(master, framework, Seconds(1)) ssc.checkpoint(checkpointDir) val fileStream = ssc.textFileStream(testDir.toString) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 952a74fd5f6de..fa04fa326e370 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.streaming import akka.actor.Actor -import akka.actor.IO -import akka.actor.IOManager import akka.actor.Props import akka.util.ByteString @@ -98,8 +96,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") // Set up the streaming context and input streams - val testDir = Files.createTempDir() - testDir.deleteOnExit() + val testDir = Utils.createTempDir() val ssc = new StreamingContext(conf, batchDuration) val fileStream = ssc.textFileStream(testDir.toString) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] @@ -144,59 +141,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") } - // TODO: This test works in IntelliJ but not through SBT - ignore("actor input stream") { - // Start the server - val testServer = new TestServer() - val port = testServer.port - testServer.start() - - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.actorStream[String](Props(new TestActor(port)), "TestActor", - // Had to pass the local value of port to prevent from closing over entire scope - StorageLevel.MEMORY_AND_DISK) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) - outputStream.register() - ssc.start() - - // Feed data to the server to send to the network receiver - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = 1 to 9 - val expectedOutput = input.map(x => x.toString) - Thread.sleep(1000) - for (i <- 0 until input.size) { - testServer.send(input(i).toString) - 
Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - } - Thread.sleep(1000) - logInfo("Stopping server") - testServer.stop() - logInfo("Stopping context") - ssc.stop() - - // Verify whether data received was as expected - logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) - logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - // (whether the elements were received one in each interval is not verified) - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i) === expectedOutput(i)) - } - } - - test("multi-thread receiver") { // set up the test receiver val numThreads = 10 @@ -378,22 +322,6 @@ class TestServer(portToBind: Int = 0) extends Logging { def port = serverSocket.getLocalPort } -/** This is an actor for testing actor input stream */ -class TestActor(port: Int) extends Actor with ActorHelper { - - def bytesToString(byteString: ByteString) = byteString.utf8String - - override def preStart(): Unit = { - @deprecated("suppress compile time deprecation warning", "1.0.0") - val unit = IOManager(context.system).connect(new InetSocketAddress(port)) - } - - def receive = { - case IO.Read(socket, bytes) => - store(bytesToString(bytes)) - } -} - /** This is a receiver to test multiple threads inserting data using block generator */ class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY_SER) with Logging { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index c53c01706083a..5dbb7232009eb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -352,8 +352,7 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) extends Thread with Logging { override def run() { - val localTestDir = Files.createTempDir() - localTestDir.deleteOnExit() + val localTestDir = Utils.createTempDir() var fs = testDir.getFileSystem(new Configuration()) val maxTries = 3 try { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 759baacaa4308..9327ff4822699 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -24,12 +24,12 @@ import scala.collection.mutable.SynchronizedBuffer import scala.reflect.ClassTag import org.scalatest.{BeforeAndAfter, FunSuite} -import com.google.common.io.Files import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} import org.apache.spark.streaming.util.ManualClock import org.apache.spark.{SparkConf, Logging} import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** * This is a input stream just for the testsuites. 
This is equivalent to a checkpointable, @@ -120,9 +120,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Directory where the checkpoint data will be saved lazy val checkpointDir = { - val dir = Files.createTempDir() + val dir = Utils.createTempDir() logDebug(s"checkpointDir: $dir") - dir.deleteOnExit() dir.toString } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 6c93d8582330b..abd37834ed3cc 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -43,7 +43,7 @@ private[yarn] class YarnAllocationHandler( args: ApplicationMasterArguments, preferredNodes: collection.Map[String, collection.Set[SplitInfo]], securityMgr: SecurityManager) - extends YarnAllocator(conf, sparkConf, args, preferredNodes, securityMgr) { + extends YarnAllocator(conf, sparkConf, appAttemptId, args, preferredNodes, securityMgr) { private val lastResponseId = new AtomicInteger() private val releaseList: CopyOnWriteArrayList[ContainerId] = new CopyOnWriteArrayList() diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 6ecac6eae6e03..14a0386b78978 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, ListBuffer, Map} import scala.util.{Try, Success, Failure} +import com.google.common.base.Objects import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission @@ -64,12 +65,12 @@ private[spark] trait ClientBase extends Logging { s"memory capability of the cluster ($maxMem MB per container)") val executorMem = args.executorMemory + executorMemoryOverhead if (executorMem > maxMem) { - throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + + throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") } val amMem = args.amMemory + amMemoryOverhead if (amMem > maxMem) { - throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + + throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") } logInfo("Will allocate AM container, with %d MB memory including %d MB overhead".format( @@ -771,15 +772,17 @@ private[spark] object ClientBase extends Logging { private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { val srcUri = srcFs.getUri() val dstUri = destFs.getUri() - if (srcUri.getScheme() == null) { - return false - } - if (!srcUri.getScheme().equals(dstUri.getScheme())) { + if (srcUri.getScheme() == null || srcUri.getScheme() != dstUri.getScheme()) { return false } + var srcHost = srcUri.getHost() var dstHost = dstUri.getHost() - if ((srcHost != null) && (dstHost != null)) { + + // In HA or when using viewfs, the host part of the URI may not actually be a host, but the + // name of the HDFS namespace. Those names won't resolve, so avoid even trying if they + // match. 
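// (Hypothetical example: with HDFS HA, hdfs://ns1/src and hdfs://ns1/dst both carry the
//  nameservice ID "ns1" as their "host"; resolving "ns1" via DNS would fail, so matching
//  authorities are accepted without attempting the InetAddress lookup below.)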
+ if (srcHost != null && dstHost != null && srcHost != dstHost) { try { srcHost = InetAddress.getByName(srcHost).getCanonicalHostName() dstHost = InetAddress.getByName(dstHost).getCanonicalHostName() @@ -787,19 +790,9 @@ private[spark] object ClientBase extends Logging { case e: UnknownHostException => return false } - if (!srcHost.equals(dstHost)) { - return false - } - } else if (srcHost == null && dstHost != null) { - return false - } else if (srcHost != null && dstHost == null) { - return false - } - if (srcUri.getPort() != dstUri.getPort()) { - false - } else { - true } + + Objects.equal(srcHost, dstHost) && srcUri.getPort() == dstUri.getPort() } } diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala index 9bd916100dd2c..17b79ae1d82c4 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala @@ -20,13 +20,10 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI -import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig -import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ @@ -117,7 +114,7 @@ class ClientBaseSuite extends FunSuite with Matchers { doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]), any(classOf[Path]), anyShort(), anyBoolean()) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() try { client.prepareLocalResources(tempDir.getAbsolutePath()) sparkConf.getOption(ClientBase.CONF_SPARK_USER_JAR) should be (Some(USER)) diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index 97eb0548e77c3..fe55d70ccc370 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -41,4 +41,55 @@ + + + + hadoop-2.2 + + 1.9 + + + + org.mortbay.jetty + jetty + 6.1.26 + + + org.mortbay.jetty + servlet-api + + + test + + + com.sun.jersey + jersey-core + ${jersey.version} + test + + + com.sun.jersey + jersey-json + ${jersey.version} + test + + + stax + stax-api + + + + + com.sun.jersey + jersey-server + ${jersey.version} + test + + + + +