diff --git a/.rat-excludes b/.rat-excludes index 769defbac11b7..a788e8273d8a2 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -1,4 +1,5 @@ target +cache .gitignore .gitattributes .project diff --git a/assembly/pom.xml b/assembly/pom.xml index c1bcdbb664dd0..fbb6e94839d42 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -36,10 +36,6 @@ scala-${scala.binary.version} spark-assembly-${project.version}-hadoop${hadoop.version}.jar ${project.build.directory}/${spark.jar.dir}/${spark.jar.basename} - spark - /usr/share/spark - root - 755 @@ -118,6 +114,16 @@ META-INF/*.RSA + + + org.jblas:jblas + + + lib/Linux/i386/** + lib/Mac OS X/** + lib/Windows/** + + @@ -217,131 +223,6 @@ - - deb - - - - maven-antrun-plugin - - - prepare-package - - run - - - - - NOTE: Debian packaging is deprecated and is scheduled to be removed in Spark 1.4. - - - - - - - - org.codehaus.mojo - buildnumber-maven-plugin - 1.2 - - - validate - - create - - - 8 - - - - - - org.vafer - jdeb - 0.11 - - - package - - jdeb - - - ${project.build.directory}/${deb.pkg.name}_${project.version}-${buildNumber}_all.deb - false - gzip - - - ${spark.jar} - file - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/jars - - - - ${basedir}/src/deb/RELEASE - file - - perm - ${deb.user} - ${deb.user} - ${deb.install.path} - - - - ${basedir}/../conf - directory - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/conf - ${deb.bin.filemode} - - - - ${basedir}/../bin - directory - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/bin - ${deb.bin.filemode} - - - - ${basedir}/../sbin - directory - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/sbin - ${deb.bin.filemode} - - - - ${basedir}/../python - directory - - perm - ${deb.user} - ${deb.user} - ${deb.install.path}/python - ${deb.bin.filemode} - - - - - - - - - - kinesis-asl diff --git a/assembly/src/deb/RELEASE b/assembly/src/deb/RELEASE deleted file mode 100644 index aad50ee73aa45..0000000000000 --- a/assembly/src/deb/RELEASE +++ /dev/null @@ -1,2 +0,0 @@ -compute-classpath.sh uses the existence of this file to decide whether to put the assembly jar on the -classpath or instead to use classfiles in the source tree. \ No newline at end of file diff --git a/assembly/src/deb/control/control b/assembly/src/deb/control/control deleted file mode 100644 index a6b4471d485f4..0000000000000 --- a/assembly/src/deb/control/control +++ /dev/null @@ -1,8 +0,0 @@ -Package: [[deb.pkg.name]] -Version: [[version]]-[[buildNumber]] -Section: misc -Priority: extra -Architecture: all -Maintainer: Matei Zaharia -Description: [[name]] -Distribution: development diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index a8c344b1ca594..f4f6b7b909490 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -76,7 +76,7 @@ fi num_jars=0 -for f in ${assembly_folder}/spark-assembly*hadoop*.jar; do +for f in "${assembly_folder}"/spark-assembly*hadoop*.jar; do if [[ ! -e "$f" ]]; then echo "Failed to find Spark assembly in $assembly_folder" 1>&2 echo "You need to build Spark before running this program." 1>&2 @@ -88,7 +88,7 @@ done if [ "$num_jars" -gt "1" ]; then echo "Found multiple Spark assembly jars in $assembly_folder:" 1>&2 - ls ${assembly_folder}/spark-assembly*hadoop*.jar 1>&2 + ls "${assembly_folder}"/spark-assembly*hadoop*.jar 1>&2 echo "Please remove all but one jar." 1>&2 exit 1 fi diff --git a/bin/run-example b/bin/run-example index c567acf9a6b5c..a106411392e06 100755 --- a/bin/run-example +++ b/bin/run-example @@ -42,7 +42,7 @@ fi JAR_COUNT=0 -for f in ${JAR_PATH}/spark-examples-*hadoop*.jar; do +for f in "${JAR_PATH}"/spark-examples-*hadoop*.jar; do if [[ ! -e "$f" ]]; then echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2 echo "You need to build Spark before running this program" 1>&2 @@ -54,7 +54,7 @@ done if [ "$JAR_COUNT" -gt "1" ]; then echo "Found multiple Spark examples assembly jars in ${JAR_PATH}" 1>&2 - ls ${JAR_PATH}/spark-examples-*hadoop*.jar 1>&2 + ls "${JAR_PATH}"/spark-examples-*hadoop*.jar 1>&2 echo "Please remove all but one jar." 1>&2 exit 1 fi diff --git a/core/pom.xml b/core/pom.xml index 66180035e61f1..c993781c0e0d6 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -329,16 +329,6 @@ scalacheck_${scala.binary.version} test - - org.easymock - easymockclassextension - test - - - asm - asm - test - junit junit diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 68b33b5f0d7c7..6c37cc8b98236 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -196,7 +196,7 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */ -.scheduler_delay, .deserialization_time, .fetch_wait_time, .serialization_time, -.getting_result_time { +.scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote, +.serialization_time, .getting_result_time { display: none; } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 80cc0587286b1..54399e99c98f0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -37,6 +37,7 @@ import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver} +import org.apache.spark.SPARK_VERSION import org.apache.spark.deploy.rest._ import org.apache.spark.executor._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -83,14 +84,25 @@ object SparkSubmit { // Exposed for testing private[spark] var exitFn: () => Unit = () => System.exit(1) private[spark] var printStream: PrintStream = System.err - private[spark] def printWarning(str: String) = printStream.println("Warning: " + str) - private[spark] def printErrorAndExit(str: String) = { + private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) + private[spark] def printErrorAndExit(str: String): Unit = { printStream.println("Error: " + str) printStream.println("Run with --help for usage help or --verbose for debug output") exitFn() } + private[spark] def printVersionAndExit(): Unit = { + printStream.println("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + printStream.println("Type --help for more information.") + exitFn() + } - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { printStream.println(appArgs) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index fa38070c6fcfe..82e66a374249c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -417,6 +417,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St verbose = true parse(tail) + case ("--version") :: tail => + SparkSubmit.printVersionAndExit() + case EQ_SEPARATED_OPT(opt, value) :: tail => parse(opt :: value :: tail) @@ -485,6 +488,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | | --help, -h Show this help message and exit | --verbose, -v Print additional debug output + | --version, Print the version of current Spark | | Spark standalone with cluster deploy mode only: | --driver-cores NUM Cores for driver (Default: 1). diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 868c63d30a202..885fa0fdbf85b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -247,6 +247,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis */ private def replay(eventLog: FileStatus, bus: ReplayListenerBus): FsApplicationHistoryInfo = { val logPath = eventLog.getPath() + logInfo(s"Replaying log path: $logPath") val (logInput, sparkVersion) = if (isLegacyLogDirectory(eventLog)) { openLegacyEventLog(logPath) @@ -256,7 +257,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis try { val appListener = new ApplicationEventListener bus.addListener(appListener) - bus.replay(logInput, sparkVersion) + bus.replay(logInput, sparkVersion, logPath.toString) new FsApplicationHistoryInfo( logPath.getName(), appListener.appId.getOrElse(logPath.getName()), diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 53e453990f8c7..8cc6ec1e8192c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -761,7 +761,7 @@ private[spark] class Master( val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), appName + " (completed)", HistoryServer.UI_PATH_PREFIX + s"/${app.id}") try { - replayBus.replay(logInput, sparkVersion) + replayBus.replay(logInput, sparkVersion, eventLogFile) } finally { logInput.close() } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index 115aa5278bb62..c4be1f19e8e9f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -19,10 +19,11 @@ package org.apache.spark.deploy.rest import java.io.{DataOutputStream, FileNotFoundException} import java.net.{HttpURLConnection, SocketException, URL} +import javax.servlet.http.HttpServletResponse import scala.io.Source -import com.fasterxml.jackson.databind.JsonMappingException +import com.fasterxml.jackson.core.JsonProcessingException import com.google.common.base.Charsets import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} @@ -155,10 +156,21 @@ private[spark] class StandaloneRestClient extends Logging { /** * Read the response from the server and return it as a validated [[SubmitRestProtocolResponse]]. * If the response represents an error, report the embedded message to the user. + * Exposed for testing. */ - private def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = { + private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = { try { - val responseJson = Source.fromInputStream(connection.getInputStream).mkString + val dataStream = + if (connection.getResponseCode == HttpServletResponse.SC_OK) { + connection.getInputStream + } else { + connection.getErrorStream + } + // If the server threw an exception while writing a response, it will not have a body + if (dataStream == null) { + throw new SubmitRestProtocolException("Server returned empty body") + } + val responseJson = Source.fromInputStream(dataStream).mkString logDebug(s"Response from the server:\n$responseJson") val response = SubmitRestProtocolMessage.fromJson(responseJson) response.validate() @@ -177,7 +189,7 @@ private[spark] class StandaloneRestClient extends Logging { case unreachable @ (_: FileNotFoundException | _: SocketException) => throw new SubmitRestConnectionException( s"Unable to connect to server ${connection.getURL}", unreachable) - case malformed @ (_: SubmitRestProtocolException | _: JsonMappingException) => + case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) => throw new SubmitRestProtocolException( "Malformed response received from server", malformed) } @@ -284,7 +296,27 @@ private[spark] object StandaloneRestClient { val REPORT_DRIVER_STATUS_MAX_TRIES = 10 val PROTOCOL_VERSION = "v1" - /** Submit an application, assuming Spark parameters are specified through system properties. */ + /** + * Submit an application, assuming Spark parameters are specified through the given config. + * This is abstracted to its own method for testing purposes. + */ + private[rest] def run( + appResource: String, + mainClass: String, + appArgs: Array[String], + conf: SparkConf, + env: Map[String, String] = sys.env): SubmitRestProtocolResponse = { + val master = conf.getOption("spark.master").getOrElse { + throw new IllegalArgumentException("'spark.master' must be set.") + } + val sparkProperties = conf.getAll.toMap + val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") } + val client = new StandaloneRestClient + val submitRequest = client.constructSubmitRequest( + appResource, mainClass, appArgs, sparkProperties, environmentVariables) + client.createSubmission(master, submitRequest) + } + def main(args: Array[String]): Unit = { if (args.size < 2) { sys.error("Usage: StandaloneRestClient [app resource] [main class] [app args*]") @@ -294,14 +326,6 @@ private[spark] object StandaloneRestClient { val mainClass = args(1) val appArgs = args.slice(2, args.size) val conf = new SparkConf - val master = conf.getOption("spark.master").getOrElse { - throw new IllegalArgumentException("'spark.master' must be set.") - } - val sparkProperties = conf.getAll.toMap - val environmentVariables = sys.env.filter { case (k, _) => k.startsWith("SPARK_") } - val client = new StandaloneRestClient - val submitRequest = client.constructSubmitRequest( - appResource, mainClass, appArgs, sparkProperties, environmentVariables) - client.createSubmission(master, submitRequest) + run(appResource, mainClass, appArgs, conf) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index acd3a2b5abe6c..f9e0478e4f874 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -17,15 +17,14 @@ package org.apache.spark.deploy.rest -import java.io.{DataOutputStream, File} +import java.io.File import java.net.InetSocketAddress import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.io.Source import akka.actor.ActorRef -import com.fasterxml.jackson.databind.JsonMappingException -import com.google.common.base.Charsets +import com.fasterxml.jackson.core.JsonProcessingException import org.eclipse.jetty.server.Server import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} import org.eclipse.jetty.util.thread.QueuedThreadPool @@ -70,14 +69,14 @@ private[spark] class StandaloneRestServer( import StandaloneRestServer._ private var _server: Option[Server] = None - private val baseContext = s"/$PROTOCOL_VERSION/submissions" - - // A mapping from servlets to the URL prefixes they are responsible for - private val servletToContext = Map[StandaloneRestServlet, String]( - new SubmitRequestServlet(masterActor, masterUrl, masterConf) -> s"$baseContext/create/*", - new KillRequestServlet(masterActor, masterConf) -> s"$baseContext/kill/*", - new StatusRequestServlet(masterActor, masterConf) -> s"$baseContext/status/*", - new ErrorServlet -> "/*" // default handler + + // A mapping from URL prefixes to servlets that serve them. Exposed for testing. + protected val baseContext = s"/$PROTOCOL_VERSION/submissions" + protected val contextToServlet = Map[String, StandaloneRestServlet]( + s"$baseContext/create/*" -> new SubmitRequestServlet(masterActor, masterUrl, masterConf), + s"$baseContext/kill/*" -> new KillRequestServlet(masterActor, masterConf), + s"$baseContext/status/*" -> new StatusRequestServlet(masterActor, masterConf), + "/*" -> new ErrorServlet // default handler ) /** Start the server and return the bound port. */ @@ -99,7 +98,7 @@ private[spark] class StandaloneRestServer( server.setThreadPool(threadPool) val mainHandler = new ServletContextHandler mainHandler.setContextPath("/") - servletToContext.foreach { case (servlet, prefix) => + contextToServlet.foreach { case (prefix, servlet) => mainHandler.addServlet(new ServletHolder(servlet), prefix) } server.setHandler(mainHandler) @@ -113,7 +112,7 @@ private[spark] class StandaloneRestServer( } } -private object StandaloneRestServer { +private[rest] object StandaloneRestServer { val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION val SC_UNKNOWN_PROTOCOL_VERSION = 468 } @@ -121,20 +120,7 @@ private object StandaloneRestServer { /** * An abstract servlet for handling requests passed to the [[StandaloneRestServer]]. */ -private abstract class StandaloneRestServlet extends HttpServlet with Logging { - - /** Service a request. If an exception is thrown in the process, indicate server error. */ - protected override def service( - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - try { - super.service(request, response) - } catch { - case e: Exception => - logError("Exception while handling request", e) - response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR) - } - } +private[rest] abstract class StandaloneRestServlet extends HttpServlet with Logging { /** * Serialize the given response message to JSON and send it through the response servlet. @@ -146,11 +132,7 @@ private abstract class StandaloneRestServlet extends HttpServlet with Logging { val message = validateResponse(responseMessage, responseServlet) responseServlet.setContentType("application/json") responseServlet.setCharacterEncoding("utf-8") - responseServlet.setStatus(HttpServletResponse.SC_OK) - val content = message.toJson.getBytes(Charsets.UTF_8) - val out = new DataOutputStream(responseServlet.getOutputStream) - out.write(content) - out.close() + responseServlet.getWriter.write(message.toJson) } /** @@ -186,6 +168,19 @@ private abstract class StandaloneRestServlet extends HttpServlet with Logging { e } + /** + * Parse a submission ID from the relative path, assuming it is the first part of the path. + * For instance, we expect the path to take the form /[submission ID]/maybe/something/else. + * The returned submission ID cannot be empty. If the path is unexpected, return None. + */ + protected def parseSubmissionId(path: String): Option[String] = { + if (path == null || path.isEmpty) { + None + } else { + path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty) + } + } + /** * Validate the response to ensure that it is correctly constructed. * @@ -209,7 +204,7 @@ private abstract class StandaloneRestServlet extends HttpServlet with Logging { /** * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. */ -private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) +private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) extends StandaloneRestServlet { /** @@ -219,18 +214,15 @@ private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) protected override def doPost( request: HttpServletRequest, response: HttpServletResponse): Unit = { - val submissionId = request.getPathInfo.stripPrefix("/") - val responseMessage = - if (submissionId.nonEmpty) { - handleKill(submissionId) - } else { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Submission ID is missing in kill request.") - } + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleKill).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in kill request.") + } sendResponse(responseMessage, response) } - private def handleKill(submissionId: String): KillSubmissionResponse = { + protected def handleKill(submissionId: String): KillSubmissionResponse = { val askTimeout = AkkaUtils.askTimeout(conf) val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout) @@ -246,7 +238,7 @@ private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) /** * A servlet for handling status requests passed to the [[StandaloneRestServer]]. */ -private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) +private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) extends StandaloneRestServlet { /** @@ -256,18 +248,15 @@ private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) protected override def doGet( request: HttpServletRequest, response: HttpServletResponse): Unit = { - val submissionId = request.getPathInfo.stripPrefix("/") - val responseMessage = - if (submissionId.nonEmpty) { - handleStatus(submissionId) - } else { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Submission ID is missing in status request.") - } + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleStatus).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in status request.") + } sendResponse(responseMessage, response) } - private def handleStatus(submissionId: String): SubmissionStatusResponse = { + protected def handleStatus(submissionId: String): SubmissionStatusResponse = { val askTimeout = AkkaUtils.askTimeout(conf) val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout) @@ -287,7 +276,7 @@ private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) /** * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. */ -private class SubmitRequestServlet( +private[rest] class SubmitRequestServlet( masterActor: ActorRef, masterUrl: String, conf: SparkConf) @@ -313,7 +302,7 @@ private class SubmitRequestServlet( handleSubmit(requestMessageJson, requestMessage, responseServlet) } catch { // The client failed to provide a valid JSON, so this is not our fault - case e @ (_: JsonMappingException | _: SubmitRestProtocolException) => + case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) => responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) handleError("Malformed request: " + formatException(e)) } @@ -413,7 +402,7 @@ private class ErrorServlet extends StandaloneRestServlet { request: HttpServletRequest, response: HttpServletResponse): Unit = { val path = request.getPathInfo - val parts = path.stripPrefix("/").split("/").toSeq + val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList var versionMismatch = false var msg = parts match { @@ -423,10 +412,10 @@ private class ErrorServlet extends StandaloneRestServlet { case `serverVersion` :: Nil => // http://host:port/correct-version "Missing the /submissions prefix." - case `serverVersion` :: "submissions" :: Nil => - // http://host:port/correct-version/submissions + case `serverVersion` :: "submissions" :: tail => + // http://host:port/correct-version/submissions/* "Missing an action: please specify one of /create, /kill, or /status." - case unknownVersion :: _ => + case unknownVersion :: tail => // http://host:port/unknown-version/* versionMismatch = true s"Unknown protocol version '$unknownVersion'." diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index bf3f1e4fc7832..df36566bec4b1 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -202,6 +202,8 @@ class TaskMetrics extends Serializable { merged.incLocalBlocksFetched(depMetrics.localBlocksFetched) merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched) merged.incRemoteBytesRead(depMetrics.remoteBytesRead) + merged.incLocalBytesRead(depMetrics.localBytesRead) + merged.incLocalReadTime(depMetrics.localReadTime) merged.incRecordsRead(depMetrics.recordsRead) } _shuffleReadMetrics = Some(merged) @@ -343,6 +345,25 @@ class ShuffleReadMetrics extends Serializable { private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value + /** + * Time the task spent (in milliseconds) reading local shuffle blocks (from the local disk). + */ + private var _localReadTime: Long = _ + def localReadTime = _localReadTime + private[spark] def incLocalReadTime(value: Long) = _localReadTime += value + + /** + * Shuffle data that was read from the local disk (as opposed to from a remote executor). + */ + private var _localBytesRead: Long = _ + def localBytesRead = _localBytesRead + private[spark] def incLocalBytesRead(value: Long) = _localBytesRead += value + + /** + * Total bytes fetched in the shuffle by this task (both remote and local). + */ + def totalBytesRead = _remoteBytesRead + _localBytesRead + /** * Number of blocks fetched in this shuffle by this task (remote or local) */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 3bb54855bae44..f9fc8aa30454e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -169,7 +169,9 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime + - " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + + " LOCAL_READ_TIME=" + metrics.localReadTime + + " LOCAL_BYTES_READ=" + metrics.localBytesRead case None => "" } val writeMetrics = taskMetrics.shuffleWriteMetrics match { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 584f4e7789d1a..d9c3a10dc5413 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -40,21 +40,24 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { * * @param logData Stream containing event log data. * @param version Spark version that generated the events. + * @param sourceName Filename (or other source identifier) from whence @logData is being read */ - def replay(logData: InputStream, version: String) { + def replay(logData: InputStream, version: String, sourceName: String) { var currentLine: String = null + var lineNumber: Int = 1 try { val lines = Source.fromInputStream(logData).getLines() lines.foreach { line => currentLine = line postToAll(JsonProtocol.sparkEventFromJson(parse(line))) + lineNumber += 1 } } catch { case ioe: IOException => throw ioe case e: Exception => - logError("Exception in parsing Spark event log.", e) - logError("Malformed line: %s\n".format(currentLine)) + logError(s"Exception parsing Spark event log: $sourceName", e) + logError(s"Malformed line #$lineNumber: $currentLine\n") } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index ab9ee4f0096bf..2ebb79989da43 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -228,12 +228,14 @@ final class ShuffleBlockFetcherIterator( * track in-memory are the ManagedBuffer references themselves. */ private[this] def fetchLocalBlocks() { + val startTime = System.currentTimeMillis val iter = localBlocks.iterator while (iter.hasNext) { val blockId = iter.next() try { val buf = blockManager.getBlockData(blockId) shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() results.put(new SuccessFetchResult(blockId, 0, buf)) } catch { @@ -244,6 +246,7 @@ final class ShuffleBlockFetcherIterator( return } } + shuffleMetrics.incLocalReadTime(System.currentTimeMillis - startTime) } private[this] def initialize(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 3a15e603b1969..cae6870c2ab20 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -37,8 +37,12 @@ private[spark] object ToolTips { "Bytes and records written to disk in order to be read by a shuffle in a future stage." val SHUFFLE_READ = - """Bytes and records read from remote executors. Typically less than shuffle write bytes - because this does not include shuffle data read locally.""" + """Total shuffle bytes and records read (includes both data read locally and data read from + remote executors). """ + + val SHUFFLE_READ_REMOTE_SIZE = + """Total shuffle bytes read from remote executors. This is a subset of the shuffle + read bytes; the remaining shuffle data is read locally. """ val GETTING_RESULT_TIME = """Time that the driver spends fetching task results from workers. If this is large, consider diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index f463f8d7c7215..0b6fe70bd2062 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -401,9 +401,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { execSummary.shuffleWriteRecords += shuffleWriteRecordsDelta val shuffleReadDelta = - (taskMetrics.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L) - - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L)) - stageData.shuffleReadBytes += shuffleReadDelta + (taskMetrics.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.totalBytesRead).getOrElse(0L)) + stageData.shuffleReadTotalBytes += shuffleReadDelta execSummary.shuffleRead += shuffleReadDelta val shuffleReadRecordsDelta = diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 05ffd5bc58fbb..d752434ad58ae 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -85,7 +85,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {if (stageData.hasShuffleRead) {
  • Shuffle read: - {s"${Utils.bytesToString(stageData.shuffleReadBytes)} / " + + {s"${Utils.bytesToString(stageData.shuffleReadTotalBytes)} / " + s"${stageData.shuffleReadRecords}"}
  • }} @@ -143,6 +143,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Shuffle Read Blocked Time +
  • + + + Shuffle Remote Reads + +
  • }}
  • metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble } - val shuffleReadBlockedQuantiles = Shuffle Read Blocked Time +: + val shuffleReadBlockedQuantiles = + + + Shuffle Read Blocked Time + + +: getFormattedTimeQuantiles(shuffleReadBlockedTimes) - val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble + val shuffleReadTotalSizes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L).toDouble } - - val shuffleReadRecords = validTasks.map { case TaskUIData(_, metrics, _) => + val shuffleReadTotalRecords = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L).toDouble } + val shuffleReadTotalQuantiles = + + + Shuffle Read Size / Records + + +: + getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) - val shuffleReadQuantiles = Shuffle Read Size / Records (Remote) +: - getFormattedSizeQuantilesWithRecords(shuffleReadSizes, shuffleReadRecords) + val shuffleReadRemoteSizes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble + } + val shuffleReadRemoteQuantiles = + + + Shuffle Remote Reads + + +: + getFormattedSizeQuantiles(shuffleReadRemoteSizes) val shuffleWriteSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble @@ -374,7 +404,10 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {shuffleReadBlockedQuantiles} - {shuffleReadQuantiles} + {shuffleReadTotalQuantiles} + + {shuffleReadRemoteQuantiles} + } else { Nil }, @@ -454,11 +487,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val shuffleReadBlockedTimeReadable = maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") - val shuffleReadSortable = maybeShuffleRead.map(_.remoteBytesRead.toString).getOrElse("") - val shuffleReadReadable = maybeShuffleRead - .map(m => s"${Utils.bytesToString(m.remoteBytesRead)}").getOrElse("") + val totalShuffleBytes = maybeShuffleRead.map(_.totalBytesRead) + val shuffleReadSortable = totalShuffleBytes.map(_.toString).getOrElse("") + val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") + val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead) + val shuffleReadRemoteSortable = remoteShuffleBytes.map(_.toString).getOrElse("") + val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") + val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics) val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten.toString).getOrElse("") val shuffleWriteReadable = maybeShuffleWrite @@ -536,6 +573,10 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {s"$shuffleReadReadable / $shuffleReadRecords"} + + {shuffleReadRemoteReadable} + }} {if (hasShuffleWrite) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 703d43f9c640d..5865850fa09b5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -138,7 +138,7 @@ private[ui] class StageTableBase( val inputReadWithUnit = if (inputRead > 0) Utils.bytesToString(inputRead) else "" val outputWrite = stageData.outputBytes val outputWriteWithUnit = if (outputWrite > 0) Utils.bytesToString(outputWrite) else "" - val shuffleRead = stageData.shuffleReadBytes + val shuffleRead = stageData.shuffleReadTotalBytes val shuffleReadWithUnit = if (shuffleRead > 0) Utils.bytesToString(shuffleRead) else "" val shuffleWrite = stageData.shuffleWriteBytes val shuffleWriteWithUnit = if (shuffleWrite > 0) Utils.bytesToString(shuffleWrite) else "" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 37cf2c207ba40..9bf67db8acde1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -28,6 +28,7 @@ private[spark] object TaskDetailsClassNames { val SCHEDULER_DELAY = "scheduler_delay" val TASK_DESERIALIZATION_TIME = "deserialization_time" val SHUFFLE_READ_BLOCKED_TIME = "fetch_wait_time" + val SHUFFLE_READ_REMOTE_SIZE = "shuffle_read_remote" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 69aac6c862de5..dbf1ceeda1878 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -80,7 +80,7 @@ private[jobs] object UIData { var inputRecords: Long = _ var outputBytes: Long = _ var outputRecords: Long = _ - var shuffleReadBytes: Long = _ + var shuffleReadTotalBytes: Long = _ var shuffleReadRecords : Long = _ var shuffleWriteBytes: Long = _ var shuffleWriteRecords: Long = _ @@ -96,7 +96,7 @@ private[jobs] object UIData { def hasInput = inputBytes > 0 def hasOutput = outputBytes > 0 - def hasShuffleRead = shuffleReadBytes > 0 + def hasShuffleRead = shuffleReadTotalBytes > 0 def hasShuffleWrite = shuffleWriteBytes > 0 def hasBytesSpilled = memoryBytesSpilled > 0 && diskBytesSpilled > 0 } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index b0b545640f5aa..58d37e2d667f7 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -294,6 +294,8 @@ private[spark] object JsonProtocol { ("Local Blocks Fetched" -> shuffleReadMetrics.localBlocksFetched) ~ ("Fetch Wait Time" -> shuffleReadMetrics.fetchWaitTime) ~ ("Remote Bytes Read" -> shuffleReadMetrics.remoteBytesRead) ~ + ("Local Read Time" -> shuffleReadMetrics.localReadTime) ~ + ("Local Bytes Read" -> shuffleReadMetrics.localBytesRead) ~ ("Total Records Read" -> shuffleReadMetrics.recordsRead) } @@ -674,6 +676,8 @@ private[spark] object JsonProtocol { metrics.incLocalBlocksFetched((json \ "Local Blocks Fetched").extract[Int]) metrics.incFetchWaitTime((json \ "Fetch Wait Time").extract[Long]) metrics.incRemoteBytesRead((json \ "Remote Bytes Read").extract[Long]) + metrics.incLocalReadTime((json \ "Local Read Time").extractOpt[Long].getOrElse(0)) + metrics.incLocalBytesRead((json \ "Local Bytes Read").extractOpt[Long].getOrElse(0)) metrics.incRecordsRead((json \ "Total Records Read").extractOpt[Long].getOrElse(0)) metrics } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 6af8dd555f2aa..c06bd6fab0cc9 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -283,13 +283,6 @@ private[spark] object Utils extends Logging { dir = new File(root, namePrefix + "-" + UUID.randomUUID.toString) if (dir.exists() || !dir.mkdirs()) { dir = null - } else { - // Restrict file permissions via chmod if available. - // For Windows this step is ignored. - if (!isWindows && !chmod700(dir)) { - dir.delete() - dir = null - } } } catch { case e: SecurityException => dir = null; } } @@ -703,7 +696,9 @@ private[spark] object Utils extends Logging { try { val rootDir = new File(root) if (rootDir.exists || rootDir.mkdirs()) { - Some(createDirectory(root).getAbsolutePath()) + val dir = createDirectory(root) + chmod700(dir) + Some(dir.getAbsolutePath) } else { logError(s"Failed to create dir in $root. Ignoring this directory.") None diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index eaec5a71e6819..d69f2d9048055 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -723,6 +723,7 @@ private[spark] class ExternalSorter[K, V, C]( partitionWriters.foreach(_.commitAndClose()) var out: FileOutputStream = null var in: FileInputStream = null + val writeStartTime = System.nanoTime try { out = new FileOutputStream(outputFile, true) for (i <- 0 until numPartitions) { @@ -739,6 +740,8 @@ private[spark] class ExternalSorter[K, V, C]( if (in != null) { in.close() } + context.taskMetrics.shuffleWriteMetrics.foreach( + _.incShuffleWriteTime(System.nanoTime - writeStartTime)) } } else { // Either we're not bypassing merge-sort or we have only in-memory data; get an iterator by diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index d7d9dc7b50f30..4b25c200a695a 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -17,16 +17,18 @@ package org.apache.spark +import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, FunSuite} -import org.scalatest.mock.EasyMockSugar +import org.scalatest.mock.MockitoSugar -import org.apache.spark.executor.{DataReadMethod, TaskMetrics} +import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.RDD import org.apache.spark.storage._ // TODO: Test the CacheManager's thread-safety aspects -class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar { - var sc : SparkContext = _ +class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter + with MockitoSugar { + var blockManager: BlockManager = _ var cacheManager: CacheManager = _ var split: Partition = _ @@ -57,10 +59,6 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar }.cache() } - after { - sc.stop() - } - test("get uncached rdd") { // Do not mock this test, because attempting to match Array[Any], which is not covariant, // in blockManager.put is a losing battle. You have been warned. @@ -75,29 +73,21 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } test("get cached rdd") { - expecting { - val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) - blockManager.get(RDDBlockId(0, 0)).andReturn(Some(result)) - } + val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) + when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - whenExecuting(blockManager) { - val context = new TaskContextImpl(0, 0, 0, 0) - val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) - assert(value.toList === List(5, 6, 7)) - } + val context = new TaskContextImpl(0, 0, 0, 0) + val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) + assert(value.toList === List(5, 6, 7)) } test("get uncached local rdd") { - expecting { - // Local computation should not persist the resulting value, so don't expect a put(). - blockManager.get(RDDBlockId(0, 0)).andReturn(None) - } + // Local computation should not persist the resulting value, so don't expect a put(). + when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - whenExecuting(blockManager) { - val context = new TaskContextImpl(0, 0, 0, 0, true) - val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) - assert(value.toList === List(1, 2, 3, 4)) - } + val context = new TaskContextImpl(0, 0, 0, 0, true) + val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) + assert(value.toList === List(1, 2, 3, 4)) } test("verify task metrics updated correctly") { diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 29aed89b67aa7..a345e06ecb7d2 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -17,141 +17,412 @@ package org.apache.spark.deploy.rest -import java.io.{File, FileInputStream, FileOutputStream, PrintWriter} -import java.util.jar.{JarEntry, JarOutputStream} -import java.util.zip.ZipEntry +import java.io.DataOutputStream +import java.net.{HttpURLConnection, URL} +import javax.servlet.http.HttpServletResponse -import scala.collection.mutable.ArrayBuffer -import scala.io.Source +import scala.collection.mutable -import akka.actor.ActorSystem -import com.google.common.io.ByteStreams -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} -import org.scalatest.exceptions.TestFailedException +import akka.actor.{Actor, ActorRef, ActorSystem, Props} +import com.google.common.base.Charsets +import org.scalatest.{BeforeAndAfterEach, FunSuite} +import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments} -import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.deploy.worker.Worker +import org.apache.spark.deploy.master.DriverState._ /** - * End-to-end tests for the REST application submission protocol in standalone mode. + * Tests for the REST application submission protocol used in standalone cluster mode. */ -class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { - private val systemsToStop = new ArrayBuffer[ActorSystem] - private val masterRestUrl = startLocalCluster() +class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { private val client = new StandaloneRestClient - private val mainJar = StandaloneRestSubmitSuite.createJar() - private val mainClass = StandaloneRestApp.getClass.getName.stripSuffix("$") + private var actorSystem: Option[ActorSystem] = None + private var server: Option[StandaloneRestServer] = None - override def afterAll() { - systemsToStop.foreach(_.shutdown()) + override def afterEach() { + actorSystem.foreach(_.shutdown()) + server.foreach(_.stop()) } - test("simple submit until completion") { - val resultsFile = File.createTempFile("test-submit", ".txt") - val numbers = Seq(1, 2, 3) - val size = 500 - val submissionId = submitApplication(resultsFile, numbers, size) - waitUntilFinished(submissionId) - validateResult(resultsFile, numbers, size) + test("construct submit request") { + val appArgs = Array("one", "two", "three") + val sparkProperties = Map("spark.app.name" -> "pi") + val environmentVariables = Map("SPARK_ONE" -> "UN", "SPARK_TWO" -> "DEUX") + val request = client.constructSubmitRequest( + "my-app-resource", "my-main-class", appArgs, sparkProperties, environmentVariables) + assert(request.action === Utils.getFormattedClassName(request)) + assert(request.clientSparkVersion === SPARK_VERSION) + assert(request.appResource === "my-app-resource") + assert(request.mainClass === "my-main-class") + assert(request.appArgs === appArgs) + assert(request.sparkProperties === sparkProperties) + assert(request.environmentVariables === environmentVariables) } - test("kill empty submission") { - val response = client.killSubmission(masterRestUrl, "submission-that-does-not-exist") - val killResponse = getKillResponse(response) - val killSuccess = killResponse.success - assert(!killSuccess) + test("create submission") { + val submittedDriverId = "my-driver-id" + val submitMessage = "your driver is submitted" + val masterUrl = startDummyServer(submitId = submittedDriverId, submitMessage = submitMessage) + val appArgs = Array("one", "two", "four") + val request = constructSubmitRequest(masterUrl, appArgs) + assert(request.appArgs === appArgs) + assert(request.sparkProperties("spark.master") === masterUrl) + val response = client.createSubmission(masterUrl, request) + val submitResponse = getSubmitResponse(response) + assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) + assert(submitResponse.serverSparkVersion === SPARK_VERSION) + assert(submitResponse.message === submitMessage) + assert(submitResponse.submissionId === submittedDriverId) + assert(submitResponse.success) + } + + test("create submission from main method") { + val submittedDriverId = "your-driver-id" + val submitMessage = "my driver is submitted" + val masterUrl = startDummyServer(submitId = submittedDriverId, submitMessage = submitMessage) + val conf = new SparkConf(loadDefaults = false) + conf.set("spark.master", masterUrl) + conf.set("spark.app.name", "dreamer") + val appArgs = Array("one", "two", "six") + // main method calls this + val response = StandaloneRestClient.run("app-resource", "main-class", appArgs, conf) + val submitResponse = getSubmitResponse(response) + assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) + assert(submitResponse.serverSparkVersion === SPARK_VERSION) + assert(submitResponse.message === submitMessage) + assert(submitResponse.submissionId === submittedDriverId) + assert(submitResponse.success) } - test("kill running submission") { - val resultsFile = File.createTempFile("test-kill", ".txt") - val numbers = Seq(1, 2, 3) - val size = 500 - val submissionId = submitApplication(resultsFile, numbers, size) - val response = client.killSubmission(masterRestUrl, submissionId) + test("kill submission") { + val submissionId = "my-lyft-driver" + val killMessage = "your driver is killed" + val masterUrl = startDummyServer(killMessage = killMessage) + val response = client.killSubmission(masterUrl, submissionId) val killResponse = getKillResponse(response) - val killSuccess = killResponse.success - waitUntilFinished(submissionId) - val response2 = client.requestSubmissionStatus(masterRestUrl, submissionId) - val statusResponse = getStatusResponse(response2) - val statusSuccess = statusResponse.success - val driverState = statusResponse.driverState - assert(killSuccess) - assert(statusSuccess) - assert(driverState === DriverState.KILLED.toString) - // we should not see the expected results because we killed the submission - intercept[TestFailedException] { validateResult(resultsFile, numbers, size) } + assert(killResponse.action === Utils.getFormattedClassName(killResponse)) + assert(killResponse.serverSparkVersion === SPARK_VERSION) + assert(killResponse.message === killMessage) + assert(killResponse.submissionId === submissionId) + assert(killResponse.success) } - test("request status for empty submission") { - val response = client.requestSubmissionStatus(masterRestUrl, "submission-that-does-not-exist") + test("request submission status") { + val submissionId = "my-uber-driver" + val submissionState = KILLED + val submissionException = new Exception("there was an irresponsible mix of alcohol and cars") + val masterUrl = startDummyServer(state = submissionState, exception = Some(submissionException)) + val response = client.requestSubmissionStatus(masterUrl, submissionId) val statusResponse = getStatusResponse(response) - val statusSuccess = statusResponse.success - assert(!statusSuccess) + assert(statusResponse.action === Utils.getFormattedClassName(statusResponse)) + assert(statusResponse.serverSparkVersion === SPARK_VERSION) + assert(statusResponse.message.contains(submissionException.getMessage)) + assert(statusResponse.submissionId === submissionId) + assert(statusResponse.driverState === submissionState.toString) + assert(statusResponse.success) + } + + test("create then kill") { + val masterUrl = startSmartServer() + val request = constructSubmitRequest(masterUrl) + val response1 = client.createSubmission(masterUrl, request) + val submitResponse = getSubmitResponse(response1) + assert(submitResponse.success) + assert(submitResponse.submissionId != null) + // kill submission that was just created + val submissionId = submitResponse.submissionId + val response2 = client.killSubmission(masterUrl, submissionId) + val killResponse = getKillResponse(response2) + assert(killResponse.success) + assert(killResponse.submissionId === submissionId) + } + + test("create then request status") { + val masterUrl = startSmartServer() + val request = constructSubmitRequest(masterUrl) + val response1 = client.createSubmission(masterUrl, request) + val submitResponse = getSubmitResponse(response1) + assert(submitResponse.success) + assert(submitResponse.submissionId != null) + // request status of submission that was just created + val submissionId = submitResponse.submissionId + val response2 = client.requestSubmissionStatus(masterUrl, submissionId) + val statusResponse = getStatusResponse(response2) + assert(statusResponse.success) + assert(statusResponse.submissionId === submissionId) + assert(statusResponse.driverState === RUNNING.toString) + } + + test("create then kill then request status") { + val masterUrl = startSmartServer() + val request = constructSubmitRequest(masterUrl) + val response1 = client.createSubmission(masterUrl, request) + val response2 = client.createSubmission(masterUrl, request) + val submitResponse1 = getSubmitResponse(response1) + val submitResponse2 = getSubmitResponse(response2) + assert(submitResponse1.success) + assert(submitResponse2.success) + assert(submitResponse1.submissionId != null) + assert(submitResponse2.submissionId != null) + val submissionId1 = submitResponse1.submissionId + val submissionId2 = submitResponse2.submissionId + // kill only submission 1, but not submission 2 + val response3 = client.killSubmission(masterUrl, submissionId1) + val killResponse = getKillResponse(response3) + assert(killResponse.success) + assert(killResponse.submissionId === submissionId1) + // request status for both submissions: 1 should be KILLED but 2 should be RUNNING still + val response4 = client.requestSubmissionStatus(masterUrl, submissionId1) + val response5 = client.requestSubmissionStatus(masterUrl, submissionId2) + val statusResponse1 = getStatusResponse(response4) + val statusResponse2 = getStatusResponse(response5) + assert(statusResponse1.submissionId === submissionId1) + assert(statusResponse2.submissionId === submissionId2) + assert(statusResponse1.driverState === KILLED.toString) + assert(statusResponse2.driverState === RUNNING.toString) + } + + test("kill or request status before create") { + val masterUrl = startSmartServer() + val doesNotExist = "does-not-exist" + // kill a non-existent submission + val response1 = client.killSubmission(masterUrl, doesNotExist) + val killResponse = getKillResponse(response1) + assert(!killResponse.success) + assert(killResponse.submissionId === doesNotExist) + // request status for a non-existent submission + val response2 = client.requestSubmissionStatus(masterUrl, doesNotExist) + val statusResponse = getStatusResponse(response2) + assert(!statusResponse.success) + assert(statusResponse.submissionId === doesNotExist) + } + + /* ---------------------------------------- * + | Aberrant client / server behavior | + * ---------------------------------------- */ + + test("good request paths") { + val masterUrl = startSmartServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val json = constructSubmitRequest(masterUrl).toJson + val submitRequestPath = s"$httpUrl/$v/submissions/create" + val killRequestPath = s"$httpUrl/$v/submissions/kill" + val statusRequestPath = s"$httpUrl/$v/submissions/status" + val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, "POST", json) + val (response2, code2) = sendHttpRequestWithResponse(s"$killRequestPath/anything", "POST") + val (response3, code3) = sendHttpRequestWithResponse(s"$killRequestPath/any/thing", "POST") + val (response4, code4) = sendHttpRequestWithResponse(s"$statusRequestPath/anything", "GET") + val (response5, code5) = sendHttpRequestWithResponse(s"$statusRequestPath/any/thing", "GET") + // these should all succeed and the responses should be of the correct types + getSubmitResponse(response1) + val killResponse1 = getKillResponse(response2) + val killResponse2 = getKillResponse(response3) + val statusResponse1 = getStatusResponse(response4) + val statusResponse2 = getStatusResponse(response5) + assert(killResponse1.submissionId === "anything") + assert(killResponse2.submissionId === "any") + assert(statusResponse1.submissionId === "anything") + assert(statusResponse2.submissionId === "any") + assert(code1 === HttpServletResponse.SC_OK) + assert(code2 === HttpServletResponse.SC_OK) + assert(code3 === HttpServletResponse.SC_OK) + assert(code4 === HttpServletResponse.SC_OK) + assert(code5 === HttpServletResponse.SC_OK) + } + + test("good request paths, bad requests") { + val masterUrl = startSmartServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val submitRequestPath = s"$httpUrl/$v/submissions/create" + val killRequestPath = s"$httpUrl/$v/submissions/kill" + val statusRequestPath = s"$httpUrl/$v/submissions/status" + val goodJson = constructSubmitRequest(masterUrl).toJson + val badJson1 = goodJson.replaceAll("action", "fraction") // invalid JSON + val badJson2 = goodJson.substring(goodJson.size / 2) // malformed JSON + val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, "POST") // missing JSON + val (response2, code2) = sendHttpRequestWithResponse(submitRequestPath, "POST", badJson1) + val (response3, code3) = sendHttpRequestWithResponse(submitRequestPath, "POST", badJson2) + val (response4, code4) = sendHttpRequestWithResponse(killRequestPath, "POST") // missing ID + val (response5, code5) = sendHttpRequestWithResponse(s"$killRequestPath/", "POST") + val (response6, code6) = sendHttpRequestWithResponse(statusRequestPath, "GET") // missing ID + val (response7, code7) = sendHttpRequestWithResponse(s"$statusRequestPath/", "GET") + // these should all fail as error responses + getErrorResponse(response1) + getErrorResponse(response2) + getErrorResponse(response3) + getErrorResponse(response4) + getErrorResponse(response5) + getErrorResponse(response6) + getErrorResponse(response7) + assert(code1 === HttpServletResponse.SC_BAD_REQUEST) + assert(code2 === HttpServletResponse.SC_BAD_REQUEST) + assert(code3 === HttpServletResponse.SC_BAD_REQUEST) + assert(code4 === HttpServletResponse.SC_BAD_REQUEST) + assert(code5 === HttpServletResponse.SC_BAD_REQUEST) + assert(code6 === HttpServletResponse.SC_BAD_REQUEST) + assert(code7 === HttpServletResponse.SC_BAD_REQUEST) + } + + test("bad request paths") { + val masterUrl = startSmartServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val (response1, code1) = sendHttpRequestWithResponse(httpUrl, "GET") + val (response2, code2) = sendHttpRequestWithResponse(s"$httpUrl/", "GET") + val (response3, code3) = sendHttpRequestWithResponse(s"$httpUrl/$v", "GET") + val (response4, code4) = sendHttpRequestWithResponse(s"$httpUrl/$v/", "GET") + val (response5, code5) = sendHttpRequestWithResponse(s"$httpUrl/$v/submissions", "GET") + val (response6, code6) = sendHttpRequestWithResponse(s"$httpUrl/$v/submissions/", "GET") + val (response7, code7) = sendHttpRequestWithResponse(s"$httpUrl/$v/submissions/bad", "GET") + val (response8, code8) = sendHttpRequestWithResponse(s"$httpUrl/bad-version", "GET") + assert(code1 === HttpServletResponse.SC_BAD_REQUEST) + assert(code2 === HttpServletResponse.SC_BAD_REQUEST) + assert(code3 === HttpServletResponse.SC_BAD_REQUEST) + assert(code4 === HttpServletResponse.SC_BAD_REQUEST) + assert(code5 === HttpServletResponse.SC_BAD_REQUEST) + assert(code6 === HttpServletResponse.SC_BAD_REQUEST) + assert(code7 === HttpServletResponse.SC_BAD_REQUEST) + assert(code8 === StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) + // all responses should be error responses + val errorResponse1 = getErrorResponse(response1) + val errorResponse2 = getErrorResponse(response2) + val errorResponse3 = getErrorResponse(response3) + val errorResponse4 = getErrorResponse(response4) + val errorResponse5 = getErrorResponse(response5) + val errorResponse6 = getErrorResponse(response6) + val errorResponse7 = getErrorResponse(response7) + val errorResponse8 = getErrorResponse(response8) + // only the incompatible version response should have server protocol version set + assert(errorResponse1.highestProtocolVersion === null) + assert(errorResponse2.highestProtocolVersion === null) + assert(errorResponse3.highestProtocolVersion === null) + assert(errorResponse4.highestProtocolVersion === null) + assert(errorResponse5.highestProtocolVersion === null) + assert(errorResponse6.highestProtocolVersion === null) + assert(errorResponse7.highestProtocolVersion === null) + assert(errorResponse8.highestProtocolVersion === StandaloneRestServer.PROTOCOL_VERSION) + } + + test("server returns unknown fields") { + val masterUrl = startSmartServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val submitRequestPath = s"$httpUrl/$v/submissions/create" + val oldJson = constructSubmitRequest(masterUrl).toJson + val oldFields = parse(oldJson).asInstanceOf[JObject].obj + val newFields = oldFields ++ Seq( + JField("tomato", JString("not-a-fruit")), + JField("potato", JString("not-po-tah-to")) + ) + val newJson = pretty(render(JObject(newFields))) + // send two requests, one with the unknown fields and the other without + val (response1, code1) = sendHttpRequestWithResponse(submitRequestPath, "POST", oldJson) + val (response2, code2) = sendHttpRequestWithResponse(submitRequestPath, "POST", newJson) + val submitResponse1 = getSubmitResponse(response1) + val submitResponse2 = getSubmitResponse(response2) + assert(code1 === HttpServletResponse.SC_OK) + assert(code2 === HttpServletResponse.SC_OK) + // only the response to the modified request should have unknown fields set + assert(submitResponse1.unknownFields === null) + assert(submitResponse2.unknownFields === Array("tomato", "potato")) + } + + test("client handles faulty server") { + val masterUrl = startFaultyServer() + val httpUrl = masterUrl.replace("spark://", "http://") + val v = StandaloneRestServer.PROTOCOL_VERSION + val submitRequestPath = s"$httpUrl/$v/submissions/create" + val killRequestPath = s"$httpUrl/$v/submissions/kill/anything" + val statusRequestPath = s"$httpUrl/$v/submissions/status/anything" + val json = constructSubmitRequest(masterUrl).toJson + // server returns malformed response unwittingly + // client should throw an appropriate exception to indicate server failure + val conn1 = sendHttpRequest(submitRequestPath, "POST", json) + intercept[SubmitRestProtocolException] { client.readResponse(conn1) } + // server attempts to send invalid response, but fails internally on validation + // client should receive an error response as server is able to recover + val conn2 = sendHttpRequest(killRequestPath, "POST") + val response2 = client.readResponse(conn2) + getErrorResponse(response2) + assert(conn2.getResponseCode === HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + // server explodes internally beyond recovery + // client should throw an appropriate exception to indicate server failure + val conn3 = sendHttpRequest(statusRequestPath, "GET") + intercept[SubmitRestProtocolException] { client.readResponse(conn3) } // empty response + assert(conn3.getResponseCode === HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + } + + /* --------------------- * + | Helper methods | + * --------------------- */ + + /** Start a dummy server that responds to requests using the specified parameters. */ + private def startDummyServer( + submitId: String = "fake-driver-id", + submitMessage: String = "driver is submitted", + killMessage: String = "driver is killed", + state: DriverState = FINISHED, + exception: Option[Exception] = None): String = { + startServer(new DummyMaster(submitId, submitMessage, killMessage, state, exception)) + } + + /** Start a smarter dummy server that keeps track of submitted driver states. */ + private def startSmartServer(): String = { + startServer(new SmarterMaster) + } + + /** Start a dummy server that is faulty in many ways... */ + private def startFaultyServer(): String = { + startServer(new DummyMaster, faulty = true) } /** - * Start a local cluster containing one Master and a few Workers. - * Do not use [[org.apache.spark.deploy.LocalSparkCluster]] here because we want the REST URL. - * Return the Master's REST URL to which applications should be submitted. + * Start a [[StandaloneRestServer]] that communicates with the given actor. + * If `faulty` is true, start an [[FaultyStandaloneRestServer]] instead. + * Return the master URL that corresponds to the address of this server. */ - private def startLocalCluster(): String = { - val conf = new SparkConf(false) - .set("spark.master.rest.enabled", "true") - .set("spark.master.rest.port", "0") - val (numWorkers, coresPerWorker, memPerWorker) = (2, 1, 512) - val localHostName = Utils.localHostName() - val (masterSystem, masterPort, _, _masterRestPort) = - Master.startSystemAndActor(localHostName, 0, 0, conf) - val masterRestPort = _masterRestPort.getOrElse { fail("REST server not started on Master!") } - val masterUrl = "spark://" + localHostName + ":" + masterPort - val masterRestUrl = "spark://" + localHostName + ":" + masterRestPort - (1 to numWorkers).foreach { n => - val (workerSystem, _) = Worker.startSystemAndActor( - localHostName, 0, 0, coresPerWorker, memPerWorker, Array(masterUrl), null, Some(n)) - systemsToStop.append(workerSystem) - } - systemsToStop.append(masterSystem) - masterRestUrl + private def startServer(makeFakeMaster: => Actor, faulty: Boolean = false): String = { + val name = "test-standalone-rest-protocol" + val conf = new SparkConf + val localhost = Utils.localHostName() + val securityManager = new SecurityManager(conf) + val (_actorSystem, _) = AkkaUtils.createActorSystem(name, localhost, 0, conf, securityManager) + val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster)) + val _server = + if (faulty) { + new FaultyStandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + } else { + new StandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + } + val port = _server.start() + // set these to clean them up after every test + actorSystem = Some(_actorSystem) + server = Some(_server) + s"spark://$localhost:$port" } - /** Submit the [[StandaloneRestApp]] and return the corresponding submission ID. */ - private def submitApplication(resultsFile: File, numbers: Seq[Int], size: Int): String = { - val appArgs = Seq(resultsFile.getAbsolutePath) ++ numbers.map(_.toString) ++ Seq(size.toString) + /** Create a submit request with real parameters using Spark submit. */ + private def constructSubmitRequest( + masterUrl: String, + appArgs: Array[String] = Array.empty): CreateSubmissionRequest = { + val mainClass = "main-class-not-used" + val mainJar = "dummy-jar-not-used.jar" val commandLineArgs = Array( "--deploy-mode", "cluster", - "--master", masterRestUrl, + "--master", masterUrl, "--name", mainClass, "--class", mainClass, mainJar) ++ appArgs val args = new SparkSubmitArguments(commandLineArgs) val (_, _, sparkProperties, _) = SparkSubmit.prepareSubmitEnvironment(args) - val request = client.constructSubmitRequest( - mainJar, mainClass, appArgs.toArray, sparkProperties.toMap, Map.empty) - val response = client.createSubmission(masterRestUrl, request) - val submitResponse = getSubmitResponse(response) - val submissionId = submitResponse.submissionId - assert(submissionId != null, "Application submission was unsuccessful!") - submissionId - } - - /** Wait until the given submission has finished running up to the specified timeout. */ - private def waitUntilFinished(submissionId: String, maxSeconds: Int = 30): Unit = { - var finished = false - val expireTime = System.currentTimeMillis + maxSeconds * 1000 - while (!finished) { - val response = client.requestSubmissionStatus(masterRestUrl, submissionId) - val statusResponse = getStatusResponse(response) - val driverState = statusResponse.driverState - finished = - driverState != DriverState.SUBMITTED.toString && - driverState != DriverState.RUNNING.toString - if (System.currentTimeMillis > expireTime) { - fail(s"Driver $submissionId did not finish within $maxSeconds seconds.") - } - } + client.constructSubmitRequest( + mainJar, mainClass, appArgs, sparkProperties.toMap, Map.empty) } /** Return the response as a submit response, or fail with error otherwise. */ @@ -181,85 +452,151 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterAll with Bef } } - /** Validate whether the application produced the corrupt output. */ - private def validateResult(resultsFile: File, numbers: Seq[Int], size: Int): Unit = { - val lines = Source.fromFile(resultsFile.getAbsolutePath).getLines().toSeq - val unexpectedContent = - if (lines.nonEmpty) { - "[\n" + lines.map { l => " " + l }.mkString("\n") + "\n]" - } else { - "[EMPTY]" - } - assert(lines.size === 2, s"Unexpected content in file: $unexpectedContent") - assert(lines(0).toInt === numbers.sum, s"Sum of ${numbers.mkString(",")} is incorrect") - assert(lines(1).toInt === (size / 2) + 1, "Result of Spark job is incorrect") + /** Return the response as an error response, or fail if the response was not an error. */ + private def getErrorResponse(response: SubmitRestProtocolResponse): ErrorResponse = { + response match { + case e: ErrorResponse => e + case r => fail(s"Expected error response. Actual: ${r.toJson}") + } } -} - -private object StandaloneRestSubmitSuite { - private val pathPrefix = this.getClass.getPackage.getName.replaceAll("\\.", "/") /** - * Create a jar that contains all the class files needed for running the [[StandaloneRestApp]]. - * Return the absolute path to that jar. + * Send an HTTP request to the given URL using the method and the body specified. + * Return the connection object. */ - def createJar(): String = { - val jarFile = File.createTempFile("test-standalone-rest-protocol", ".jar") - val jarFileStream = new FileOutputStream(jarFile) - val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest) - jarStream.putNextEntry(new ZipEntry(pathPrefix)) - getClassFiles.foreach { cf => - jarStream.putNextEntry(new JarEntry(pathPrefix + "/" + cf.getName)) - val in = new FileInputStream(cf) - ByteStreams.copy(in, jarStream) - in.close() + private def sendHttpRequest( + url: String, + method: String, + body: String = ""): HttpURLConnection = { + val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod(method) + if (body.nonEmpty) { + conn.setDoOutput(true) + val out = new DataOutputStream(conn.getOutputStream) + out.write(body.getBytes(Charsets.UTF_8)) + out.close() } - jarStream.close() - jarFileStream.close() - jarFile.getAbsolutePath + conn } /** - * Return a list of class files compiled for [[StandaloneRestApp]]. - * This includes all the anonymous classes used in the application. + * Send an HTTP request to the given URL using the method and the body specified. + * Return a 2-tuple of the response message from the server and the response code. */ - private def getClassFiles: Seq[File] = { - val className = Utils.getFormattedClassName(StandaloneRestApp) - val clazz = StandaloneRestApp.getClass - val basePath = clazz.getProtectionDomain.getCodeSource.getLocation.toURI.getPath - val baseDir = new File(basePath + "/" + pathPrefix) - baseDir.listFiles().filter(_.getName.contains(className)) + private def sendHttpRequestWithResponse( + url: String, + method: String, + body: String = ""): (SubmitRestProtocolResponse, Int) = { + val conn = sendHttpRequest(url, method, body) + (client.readResponse(conn), conn.getResponseCode) } } /** - * Sample application to be submitted to the cluster using the REST gateway. - * All relevant classes will be packaged into a jar at run time. + * A mock standalone Master that responds with dummy messages. + * In all responses, the success parameter is always true. */ -object StandaloneRestApp { - // Usage: [path to results file] [num1] [num2] [num3] [rddSize] - // The first line of the results file should be (num1 + num2 + num3) - // The second line should be (rddSize / 2) + 1 - def main(args: Array[String]) { - assert(args.size == 5, s"Expected exactly 5 arguments: ${args.mkString(",")}") - val resultFile = new File(args(0)) - val writer = new PrintWriter(resultFile) - try { - val conf = new SparkConf() - val sc = new SparkContext(conf) - val firstLine = args(1).toInt + args(2).toInt + args(3).toInt - val secondLine = sc.parallelize(1 to args(4).toInt) - .map { i => (i / 2, i) } - .reduceByKey(_ + _) - .count() - writer.println(firstLine) - writer.println(secondLine) - } catch { - case e: Exception => - writer.println(e) - e.getStackTrace.foreach { l => writer.println(" " + l) } - } finally { - writer.close() +private class DummyMaster( + submitId: String = "fake-driver-id", + submitMessage: String = "submitted", + killMessage: String = "killed", + state: DriverState = FINISHED, + exception: Option[Exception] = None) + extends Actor { + + override def receive = { + case RequestSubmitDriver(driverDesc) => + sender ! SubmitDriverResponse(success = true, Some(submitId), submitMessage) + case RequestKillDriver(driverId) => + sender ! KillDriverResponse(driverId, success = true, killMessage) + case RequestDriverStatus(driverId) => + sender ! DriverStatusResponse(found = true, Some(state), None, None, exception) + } +} + +/** + * A mock standalone Master that keeps track of drivers that have been submitted. + * + * If a driver is submitted, its state is immediately set to RUNNING. + * If an existing driver is killed, its state is immediately set to KILLED. + * If an existing driver's status is requested, its state is returned in the response. + * Submits are always successful while kills and status requests are successful only + * if the driver was submitted in the past. + */ +private class SmarterMaster extends Actor { + private var counter: Int = 0 + private val submittedDrivers = new mutable.HashMap[String, DriverState] + + override def receive = { + case RequestSubmitDriver(driverDesc) => + val driverId = s"driver-$counter" + submittedDrivers(driverId) = RUNNING + counter += 1 + sender ! SubmitDriverResponse(success = true, Some(driverId), "submitted") + + case RequestKillDriver(driverId) => + val success = submittedDrivers.contains(driverId) + if (success) { + submittedDrivers(driverId) = KILLED + } + sender ! KillDriverResponse(driverId, success, "killed") + + case RequestDriverStatus(driverId) => + val found = submittedDrivers.contains(driverId) + val state = submittedDrivers.get(driverId) + sender ! DriverStatusResponse(found, state, None, None, None) + } +} + +/** + * A [[StandaloneRestServer]] that is faulty in many ways. + * + * When handling a submit request, the server returns a malformed JSON. + * When handling a kill request, the server returns an invalid JSON. + * When handling a status request, the server throws an internal exception. + * The purpose of this class is to test that client handles these cases gracefully. + */ +private class FaultyStandaloneRestServer( + host: String, + requestedPort: Int, + masterActor: ActorRef, + masterUrl: String, + masterConf: SparkConf) + extends StandaloneRestServer(host, requestedPort, masterActor, masterUrl, masterConf) { + + protected override val contextToServlet = Map[String, StandaloneRestServlet]( + s"$baseContext/create/*" -> new MalformedSubmitServlet, + s"$baseContext/kill/*" -> new InvalidKillServlet, + s"$baseContext/status/*" -> new ExplodingStatusServlet, + "/*" -> new ErrorServlet + ) + + /** A faulty servlet that produces malformed responses. */ + class MalformedSubmitServlet extends SubmitRequestServlet(masterActor, masterUrl, masterConf) { + protected override def sendResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): Unit = { + val badJson = responseMessage.toJson.drop(10).dropRight(20) + responseServlet.getWriter.write(badJson) + } + } + + /** A faulty servlet that produces invalid responses. */ + class InvalidKillServlet extends KillRequestServlet(masterActor, masterConf) { + protected override def handleKill(submissionId: String): KillSubmissionResponse = { + val k = super.handleKill(submissionId) + k.submissionId = null + k + } + } + + /** A faulty status servlet that explodes. */ + class ExplodingStatusServlet extends StatusRequestServlet(masterActor, masterConf) { + private def explode: Int = 1 / 0 + protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { + val s = super.handleStatus(submissionId) + s.workerId = explode.toString + s } } } diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 98b0a16ce88ba..2e58c159a2ed8 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.FunSuite import org.apache.hadoop.io.Text -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.util.Utils import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} @@ -42,7 +42,15 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { private var factory: CompressionCodecFactory = _ override def beforeAll() { - sc = new SparkContext("local", "test") + // Hadoop's FileSystem caching does not use the Configuration as part of its cache key, which + // can cause Filesystem.get(Configuration) to return a cached instance created with a different + // configuration than the one passed to get() (see HADOOP-8490 for more details). This caused + // hard-to-reproduce test failures, since any suites that were run after this one would inherit + // the new value of "fs.local.block.size" (see SPARK-5227 and SPARK-5679). To work around this, + // we disable FileSystem caching in this suite. + val conf = new SparkConf().set("spark.hadoop.fs.file.impl.disable.cache", "true") + + sc = new SparkContext("local", "test", conf) // Set the block size of local file system to test whether files are split right or not. sc.hadoopConfiguration.setLong("fs.local.block.size", 32) diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 7e360cc6082ec..702c4cb3bdef9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -61,7 +61,7 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { try { val replayer = new ReplayListenerBus() replayer.addListener(eventMonster) - replayer.replay(logData, SPARK_VERSION) + replayer.replay(logData, SPARK_VERSION, logFilePath.toString) } finally { logData.close() } @@ -120,7 +120,7 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { try { val replayer = new ReplayListenerBus() replayer.addListener(eventMonster) - replayer.replay(logData, version) + replayer.replay(logData, version, eventLog.getPath().toString) } finally { logData.close() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala index 46ab02bfef780..afbaa9ade811f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala @@ -17,45 +17,48 @@ package org.apache.spark.scheduler.mesos -import org.apache.spark.executor.MesosExecutorBackend -import org.scalatest.FunSuite -import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext} -import org.apache.spark.scheduler.{SparkListenerExecutorAdded, LiveListenerBus, - TaskDescription, WorkerOffer, TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.scheduler.cluster.mesos.{MemoryUtils, MesosSchedulerBackend} -import org.apache.mesos.SchedulerDriver -import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, _} -import org.apache.mesos.Protos.Value.Scalar -import org.easymock.{Capture, EasyMock} import java.nio.ByteBuffer -import java.util.Collections import java.util -import org.scalatest.mock.EasyMockSugar +import java.util.Collections import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with EasyMockSugar { +import org.apache.mesos.SchedulerDriver +import org.apache.mesos.Protos._ +import org.apache.mesos.Protos.Value.Scalar +import org.mockito.Mockito._ +import org.mockito.Matchers._ +import org.mockito.{ArgumentCaptor, Matchers} +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext} +import org.apache.spark.executor.MesosExecutorBackend +import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerExecutorAdded, + TaskDescription, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.scheduler.cluster.mesos.{MesosSchedulerBackend, MemoryUtils} + +class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with MockitoSugar { test("check spark-class location correctly") { val conf = new SparkConf conf.set("spark.mesos.executor.home" , "/mesos-home") - val listenerBus = EasyMock.createMock(classOf[LiveListenerBus]) - listenerBus.post(SparkListenerExecutorAdded(EasyMock.anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) - EasyMock.replay(listenerBus) - - val sc = EasyMock.createMock(classOf[SparkContext]) - EasyMock.expect(sc.getSparkHome()).andReturn(Option("/spark-home")).anyTimes() - EasyMock.expect(sc.conf).andReturn(conf).anyTimes() - EasyMock.expect(sc.executorEnvs).andReturn(new mutable.HashMap).anyTimes() - EasyMock.expect(sc.executorMemory).andReturn(100).anyTimes() - EasyMock.expect(sc.listenerBus).andReturn(listenerBus) - EasyMock.replay(sc) - val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl]) - EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() - EasyMock.replay(taskScheduler) + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) + + val sc = mock[SparkContext] + when(sc.getSparkHome()).thenReturn(Option("/spark-home")) + + when(sc.conf).thenReturn(conf) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.executorMemory).thenReturn(100) + when(sc.listenerBus).thenReturn(listenerBus) + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master") @@ -84,20 +87,19 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea .setSlaveId(SlaveID.newBuilder().setValue(s"s${id.toString}")).setHostname(s"host${id.toString}").build() } - val driver = EasyMock.createMock(classOf[SchedulerDriver]) - val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl]) + val driver = mock[SchedulerDriver] + val taskScheduler = mock[TaskSchedulerImpl] - val listenerBus = EasyMock.createMock(classOf[LiveListenerBus]) - listenerBus.post(SparkListenerExecutorAdded(EasyMock.anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) - EasyMock.replay(listenerBus) + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) - val sc = EasyMock.createMock(classOf[SparkContext]) - EasyMock.expect(sc.executorMemory).andReturn(100).anyTimes() - EasyMock.expect(sc.getSparkHome()).andReturn(Option("/path")).anyTimes() - EasyMock.expect(sc.executorEnvs).andReturn(new mutable.HashMap).anyTimes() - EasyMock.expect(sc.conf).andReturn(new SparkConf).anyTimes() - EasyMock.expect(sc.listenerBus).andReturn(listenerBus) - EasyMock.replay(sc) + val sc = mock[SparkContext] + when(sc.executorMemory).thenReturn(100) + when(sc.getSparkHome()).thenReturn(Option("/path")) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.conf).thenReturn(new SparkConf) + when(sc.listenerBus).thenReturn(listenerBus) val minMem = MemoryUtils.calculateTotalMemory(sc).toInt val minCpu = 4 @@ -121,25 +123,29 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea 2 )) val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) - EasyMock.expect(taskScheduler.resourceOffers(EasyMock.eq(expectedWorkerOffers))).andReturn(Seq(Seq(taskDesc))) - EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() - EasyMock.replay(taskScheduler) + when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) - val capture = new Capture[util.Collection[TaskInfo]] - EasyMock.expect( + val capture = ArgumentCaptor.forClass(classOf[util.Collection[TaskInfo]]) + when( driver.launchTasks( - EasyMock.eq(Collections.singleton(mesosOffers.get(0).getId)), - EasyMock.capture(capture), - EasyMock.anyObject(classOf[Filters]) + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + capture.capture(), + any(classOf[Filters]) ) - ).andReturn(Status.valueOf(1)).once - EasyMock.expect(driver.declineOffer(mesosOffers.get(1).getId)).andReturn(Status.valueOf(1)).times(1) - EasyMock.expect(driver.declineOffer(mesosOffers.get(2).getId)).andReturn(Status.valueOf(1)).times(1) - EasyMock.replay(driver) + ).thenReturn(Status.valueOf(1)) + when(driver.declineOffer(mesosOffers.get(1).getId)).thenReturn(Status.valueOf(1)) + when(driver.declineOffer(mesosOffers.get(2).getId)).thenReturn(Status.valueOf(1)) backend.resourceOffers(driver, mesosOffers) - EasyMock.verify(driver) + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), + capture.capture(), + any(classOf[Filters]) + ) + verify(driver, times(1)).declineOffer(mesosOffers.get(1).getId) + verify(driver, times(1)).declineOffer(mesosOffers.get(2).getId) assert(capture.getValue.size() == 1) val taskInfo = capture.getValue.iterator().next() assert(taskInfo.getName.equals("n1")) @@ -151,15 +157,13 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea // Unwanted resources offered on an existing node. Make sure they are declined val mesosOffers2 = new java.util.ArrayList[Offer] mesosOffers2.add(createOffer(1, minMem, minCpu)) - EasyMock.reset(taskScheduler) - EasyMock.reset(driver) - EasyMock.expect(taskScheduler.resourceOffers(EasyMock.anyObject(classOf[Seq[WorkerOffer]])).andReturn(Seq(Seq()))) - EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() - EasyMock.replay(taskScheduler) - EasyMock.expect(driver.declineOffer(mesosOffers2.get(0).getId)).andReturn(Status.valueOf(1)).times(1) - EasyMock.replay(driver) + reset(taskScheduler) + reset(driver) + when(taskScheduler.resourceOffers(any(classOf[Seq[WorkerOffer]]))).thenReturn(Seq(Seq())) + when(taskScheduler.CPUS_PER_TASK).thenReturn(2) + when(driver.declineOffer(mesosOffers2.get(0).getId)).thenReturn(Status.valueOf(1)) backend.resourceOffers(driver, mesosOffers2) - EasyMock.verify(driver) + verify(driver, times(1)).declineOffer(mesosOffers2.get(0).getId) } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index e8405baa8e3ea..6019282d2fb70 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -227,6 +227,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) shuffleReadMetrics.incRemoteBytesRead(base + 1) + shuffleReadMetrics.incLocalBytesRead(base + 9) shuffleReadMetrics.incRemoteBlocksFetched(base + 2) shuffleWriteMetrics.incShuffleBytesWritten(base + 3) taskMetrics.setExecutorRunTime(base + 4) @@ -260,8 +261,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc var stage0Data = listener.stageIdToData.get((0, 0)).get var stage1Data = listener.stageIdToData.get((1, 0)).get - assert(stage0Data.shuffleReadBytes == 102) - assert(stage1Data.shuffleReadBytes == 201) + assert(stage0Data.shuffleReadTotalBytes == 220) + assert(stage1Data.shuffleReadTotalBytes == 410) assert(stage0Data.shuffleWriteBytes == 106) assert(stage1Data.shuffleWriteBytes == 203) assert(stage0Data.executorRunTime == 108) @@ -290,8 +291,11 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc stage0Data = listener.stageIdToData.get((0, 0)).get stage1Data = listener.stageIdToData.get((1, 0)).get - assert(stage0Data.shuffleReadBytes == 402) - assert(stage1Data.shuffleReadBytes == 602) + // Task 1235 contributed (100+1)+(100+9) = 210 shuffle bytes, and task 1234 contributed + // (300+1)+(300+9) = 610 total shuffle bytes, so the total for the stage is 820. + assert(stage0Data.shuffleReadTotalBytes == 820) + // Task 1236 contributed 410 shuffle bytes, and task 1237 contributed 810 shuffle bytes. + assert(stage1Data.shuffleReadTotalBytes == 1220) assert(stage0Data.shuffleWriteBytes == 406) assert(stage1Data.shuffleWriteBytes == 606) assert(stage0Data.executorRunTime == 408) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index f3017dc42cd5c..c181baf6844b0 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -260,6 +260,19 @@ class JsonProtocolSuite extends FunSuite { assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent)) } + test("ShuffleReadMetrics: Local bytes read and time taken backwards compatibility") { + // Metrics about local shuffle bytes read and local read time were added in 1.3.1. + val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, + hasHadoopInput = false, hasOutput = false, hasRecords = false) + assert(metrics.shuffleReadMetrics.nonEmpty) + val newJson = JsonProtocol.taskMetricsToJson(metrics) + val oldJson = newJson.removeField { case (field, _) => field == "Local Bytes Read" } + .removeField { case (field, _) => field == "Local Read Time" } + val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) + assert(newMetrics.shuffleReadMetrics.get.localBytesRead == 0) + assert(newMetrics.shuffleReadMetrics.get.localReadTime == 0) + } + test("SparkListenerApplicationStart backwards compatibility") { // SparkListenerApplicationStart in Spark 1.0.0 do not have an "appId" property. val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user") @@ -695,6 +708,8 @@ class JsonProtocolSuite extends FunSuite { sr.incFetchWaitTime(a + d) sr.incRemoteBlocksFetched(f) sr.incRecordsRead(if (hasRecords) (b + d) / 100 else -1) + sr.incLocalReadTime(a + e) + sr.incLocalBytesRead(a + f) t.setShuffleReadMetrics(Some(sr)) } if (hasOutput) { @@ -941,6 +956,8 @@ class JsonProtocolSuite extends FunSuite { | "Local Blocks Fetched": 700, | "Fetch Wait Time": 900, | "Remote Bytes Read": 1000, + | "Local Read Time": 1000, + | "Local Bytes Read": 1100, | "Total Records Read" : 10 | }, | "Shuffle Write Metrics": { diff --git a/data/mllib/sample_isotonic_regression_data.txt b/data/mllib/sample_isotonic_regression_data.txt new file mode 100644 index 0000000000000..d257b509d4d37 --- /dev/null +++ b/data/mllib/sample_isotonic_regression_data.txt @@ -0,0 +1,100 @@ +0.24579296,0.01 +0.28505864,0.02 +0.31208567,0.03 +0.35900051,0.04 +0.35747068,0.05 +0.16675166,0.06 +0.17491076,0.07 +0.04181540,0.08 +0.04793473,0.09 +0.03926568,0.10 +0.12952575,0.11 +0.00000000,0.12 +0.01376849,0.13 +0.13105558,0.14 +0.08873024,0.15 +0.12595614,0.16 +0.15247323,0.17 +0.25956145,0.18 +0.20040796,0.19 +0.19581846,0.20 +0.15757267,0.21 +0.13717491,0.22 +0.19020908,0.23 +0.19581846,0.24 +0.20091790,0.25 +0.16879143,0.26 +0.18510964,0.27 +0.20040796,0.28 +0.29576747,0.29 +0.43396226,0.30 +0.53391127,0.31 +0.52116267,0.32 +0.48546660,0.33 +0.49209587,0.34 +0.54156043,0.35 +0.59765426,0.36 +0.56144824,0.37 +0.58592555,0.38 +0.52983172,0.39 +0.50178480,0.40 +0.52626211,0.41 +0.58286588,0.42 +0.64660887,0.43 +0.68077511,0.44 +0.74298827,0.45 +0.64864865,0.46 +0.67261601,0.47 +0.65782764,0.48 +0.69811321,0.49 +0.63029067,0.50 +0.61601224,0.51 +0.63233044,0.52 +0.65323814,0.53 +0.65323814,0.54 +0.67363590,0.55 +0.67006629,0.56 +0.51555329,0.57 +0.50892402,0.58 +0.33299337,0.59 +0.36206017,0.60 +0.43090260,0.61 +0.45996940,0.62 +0.56348802,0.63 +0.54920959,0.64 +0.48393677,0.65 +0.48495665,0.66 +0.46965834,0.67 +0.45181030,0.68 +0.45843957,0.69 +0.47118817,0.70 +0.51555329,0.71 +0.58031617,0.72 +0.55481897,0.73 +0.56297807,0.74 +0.56603774,0.75 +0.57929628,0.76 +0.64762876,0.77 +0.66241713,0.78 +0.69301377,0.79 +0.65119837,0.80 +0.68332483,0.81 +0.66598674,0.82 +0.73890872,0.83 +0.73992861,0.84 +0.84242733,0.85 +0.91330954,0.86 +0.88016318,0.87 +0.90719021,0.88 +0.93115757,0.89 +0.93115757,0.90 +0.91942886,0.91 +0.92911780,0.92 +0.95665477,0.93 +0.95002550,0.94 +0.96940337,0.95 +1.00000000,0.96 +0.89801122,0.97 +0.90311066,0.98 +0.90362060,0.99 +0.83477817,1.0 \ No newline at end of file diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index dfa924d2aa0ba..3062e9c3c6651 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -244,6 +244,8 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): versions = asf_jira.project_versions("SPARK") versions = sorted(versions, key=lambda x: x.name, reverse=True) versions = filter(lambda x: x.raw['released'] is False, versions) + # Consider only x.y.z versions + versions = filter(lambda x: re.match('\d+\.\d+\.\d+', x.name), versions) default_fix_versions = map(lambda x: fix_version_from_branch(x, versions).name, merge_branches) for v in default_fix_versions: diff --git a/docs/building-spark.md b/docs/building-spark.md index d3824fb61eef9..088da7da4980e 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -159,16 +159,6 @@ Thus, the full flow for running continuous-compilation of the `core` submodule m For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troubleshooting, refer to the [wiki page for IDE setup](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-IDESetup). -# Building Spark Debian Packages - -_NOTE: Debian packaging is deprecated and is scheduled to be removed in Spark 1.4._ - -The Maven build includes support for building a Debian package containing the assembly 'fat-jar', PySpark, and the necessary scripts and configuration files. This can be created by specifying the following: - - mvn -Pdeb -DskipTests clean package - -The debian package can then be found under assembly/target. We added the short commit hash to the file name so that we can distinguish individual packages built for SNAPSHOT versions. - # Running Java 8 Test Suites Running only Java 8 tests and nothing else. diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md index 719cc95767b00..5b9b4dd83b774 100644 --- a/docs/mllib-classification-regression.md +++ b/docs/mllib-classification-regression.md @@ -23,7 +23,7 @@ the supported algorithms for each type of problem. Multiclass Classificationdecision trees, naive Bayes - Regressionlinear least squares, Lasso, ridge regression, decision trees + Regressionlinear least squares, Lasso, ridge regression, decision trees, isotonic regression @@ -35,3 +35,4 @@ More details for these methods can be found here: * [linear regression (least squares, Lasso, ridge)](mllib-linear-methods.html#linear-least-squares-lasso-and-ridge-regression) * [Decision trees](mllib-decision-tree.html) * [Naive Bayes](mllib-naive-bayes.html) +* [Isotonic regression](mllib-isotonic-regression.html) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 99ed6b60e3f00..09b56576699e0 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -4,28 +4,25 @@ title: Clustering - MLlib displayTitle: MLlib - Clustering --- -* Table of contents -{:toc} - - -## Clustering - Clustering is an unsupervised learning problem whereby we aim to group subsets of entities with one another based on some notion of similarity. Clustering is often used for exploratory analysis and/or as a component of a hierarchical supervised learning pipeline (in which distinct classifiers or regression -models are trained for each cluster). +models are trained for each cluster). MLlib supports the following models: -### k-means +* Table of contents +{:toc} + +## K-means [k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the most commonly used clustering algorithms that clusters the data points into a predefined number of clusters. The MLlib implementation includes a parallelized variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). -The implementation in MLlib has the following parameters: +The implementation in MLlib has the following parameters: * *k* is the number of desired clusters. * *maxIterations* is the maximum number of iterations to run. @@ -35,74 +32,9 @@ initialization via k-means\|\|. guaranteed to find a globally optimal solution, and when run multiple times on a given dataset, the algorithm returns the best clustering result). * *initializationSteps* determines the number of steps in the k-means\|\| algorithm. -* *epsilon* determines the distance threshold within which we consider k-means to have converged. - -### Gaussian mixture - -A [Gaussian Mixture Model](http://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) -represents a composite distribution whereby points are drawn from one of *k* Gaussian sub-distributions, -each with its own probability. The MLlib implementation uses the -[expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) - algorithm to induce the maximum-likelihood model given a set of samples. The implementation -has the following parameters: - -* *k* is the number of desired clusters. -* *convergenceTol* is the maximum change in log-likelihood at which we consider convergence achieved. -* *maxIterations* is the maximum number of iterations to perform without reaching convergence. -* *initialModel* is an optional starting point from which to start the EM algorithm. If this parameter is omitted, a random starting point will be constructed from the data. - -### Power Iteration Clustering - -Power iteration clustering is a scalable and efficient algorithm for clustering points given pointwise mutual affinity values. Internally the algorithm: - -* accepts a [Graph](api/graphx/index.html#org.apache.spark.graphx.Graph) that represents a normalized pairwise affinity between all input points. -* calculates the principal eigenvalue and eigenvector -* Clusters each of the input points according to their principal eigenvector component value - -Details of this algorithm are found within [Power Iteration Clustering, Lin and Cohen]{www.icml2010.org/papers/387.pdf} - -Example outputs for a dataset inspired by the paper - but with five clusters instead of three- have he following output from our implementation: - -

    - The Property Graph - -

    - -### Latent Dirichlet Allocation (LDA) - -[Latent Dirichlet Allocation (LDA)](http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation) -is a topic model which infers topics from a collection of text documents. -LDA can be thought of as a clustering algorithm as follows: - -* Topics correspond to cluster centers, and documents correspond to examples (rows) in a dataset. -* Topics and documents both exist in a feature space, where feature vectors are vectors of word counts. -* Rather than estimating a clustering using a traditional distance, LDA uses a function based - on a statistical model of how text documents are generated. - -LDA takes in a collection of documents as vectors of word counts. -It learns clustering using [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) -on the likelihood function. After fitting on the documents, LDA provides: - -* Topics: Inferred topics, each of which is a probability distribution over terms (words). -* Topic distributions for documents: For each document in the training set, LDA gives a probability distribution over topics. - -LDA takes the following parameters: - -* `k`: Number of topics (i.e., cluster centers) -* `maxIterations`: Limit on the number of iterations of EM used for learning -* `docConcentration`: Hyperparameter for prior over documents' distributions over topics. Currently must be > 1, where larger values encourage smoother inferred distributions. -* `topicConcentration`: Hyperparameter for prior over topics' distributions over terms (words). Currently must be > 1, where larger values encourage smoother inferred distributions. -* `checkpointInterval`: If using checkpointing (set in the Spark configuration), this parameter specifies the frequency with which checkpoints will be created. If `maxIterations` is large, using checkpointing can help reduce shuffle file sizes on disk and help with failure recovery. - -*Note*: LDA is a new feature with some missing functionality. In particular, it does not yet -support prediction on new documents, and it does not have a Python API. These will be added in the future. +* *epsilon* determines the distance threshold within which we consider k-means to have converged. -### Examples - -#### k-means +**Examples**
    @@ -216,13 +148,27 @@ print("Within Set Sum of Squared Error = " + str(WSSSE))
    -#### GaussianMixture +## Gaussian mixture + +A [Gaussian Mixture Model](http://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) +represents a composite distribution whereby points are drawn from one of *k* Gaussian sub-distributions, +each with its own probability. The MLlib implementation uses the +[expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) + algorithm to induce the maximum-likelihood model given a set of samples. The implementation +has the following parameters: + +* *k* is the number of desired clusters. +* *convergenceTol* is the maximum change in log-likelihood at which we consider convergence achieved. +* *maxIterations* is the maximum number of iterations to perform without reaching convergence. +* *initialModel* is an optional starting point from which to start the EM algorithm. If this parameter is omitted, a random starting point will be constructed from the data. + +**Examples**
    In the following example after loading and parsing data, we use a -[GaussianMixture](api/scala/index.html#org.apache.spark.mllib.clustering.GaussianMixture) -object to cluster the data into two clusters. The number of desired clusters is passed +[GaussianMixture](api/scala/index.html#org.apache.spark.mllib.clustering.GaussianMixture) +object to cluster the data into two clusters. The number of desired clusters is passed to the algorithm. We then output the parameters of the mixture model. {% highlight scala %} @@ -238,7 +184,7 @@ val gmm = new GaussianMixture().setK(2).run(parsedData) // output parameters of max-likelihood model for (i <- 0 until gmm.k) { - println("weight=%f\nmu=%s\nsigma=\n%s\n" format + println("weight=%f\nmu=%s\nsigma=\n%s\n" format (gmm.weights(i), gmm.gaussians(i).mu, gmm.gaussians(i).sigma)) } @@ -298,7 +244,7 @@ public class GaussianMixtureExample {
    In the following example after loading and parsing data, we use a [GaussianMixture](api/python/pyspark.mllib.html#pyspark.mllib.clustering.GaussianMixture) -object to cluster the data into two clusters. The number of desired clusters is passed +object to cluster the data into two clusters. The number of desired clusters is passed to the algorithm. We then output the parameters of the mixture model. {% highlight python %} @@ -322,11 +268,60 @@ for i in range(2):
    -#### Latent Dirichlet Allocation (LDA) Example +## Power iteration clustering (PIC) + +Power iteration clustering (PIC) is a scalable and efficient algorithm for clustering points given pointwise mutual affinity values. Internally the algorithm: + +* accepts a [Graph](api/graphx/index.html#org.apache.spark.graphx.Graph) that represents a normalized pairwise affinity between all input points. +* calculates the principal eigenvalue and eigenvector +* Clusters each of the input points according to their principal eigenvector component value + +Details of this algorithm are found within [Power Iteration Clustering, Lin and Cohen]{www.icml2010.org/papers/387.pdf} + +Example outputs for a dataset inspired by the paper - but with five clusters instead of three- have he following output from our implementation: + +

    + The Property Graph + +

    + +## Latent Dirichlet allocation (LDA) + +[Latent Dirichlet allocation (LDA)](http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation) +is a topic model which infers topics from a collection of text documents. +LDA can be thought of as a clustering algorithm as follows: + +* Topics correspond to cluster centers, and documents correspond to examples (rows) in a dataset. +* Topics and documents both exist in a feature space, where feature vectors are vectors of word counts. +* Rather than estimating a clustering using a traditional distance, LDA uses a function based + on a statistical model of how text documents are generated. + +LDA takes in a collection of documents as vectors of word counts. +It learns clustering using [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) +on the likelihood function. After fitting on the documents, LDA provides: + +* Topics: Inferred topics, each of which is a probability distribution over terms (words). +* Topic distributions for documents: For each document in the training set, LDA gives a probability distribution over topics. + +LDA takes the following parameters: + +* `k`: Number of topics (i.e., cluster centers) +* `maxIterations`: Limit on the number of iterations of EM used for learning +* `docConcentration`: Hyperparameter for prior over documents' distributions over topics. Currently must be > 1, where larger values encourage smoother inferred distributions. +* `topicConcentration`: Hyperparameter for prior over topics' distributions over terms (words). Currently must be > 1, where larger values encourage smoother inferred distributions. +* `checkpointInterval`: If using checkpointing (set in the Spark configuration), this parameter specifies the frequency with which checkpoints will be created. If `maxIterations` is large, using checkpointing can help reduce shuffle file sizes on disk and help with failure recovery. + +*Note*: LDA is a new feature with some missing functionality. In particular, it does not yet +support prediction on new documents, and it does not have a Python API. These will be added in the future. + +**Examples** In the following example, we load word count vectors representing a corpus of documents. We then use [LDA](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) -to infer three topics from the documents. The number of desired clusters is passed +to infer three topics from the documents. The number of desired clusters is passed to the algorithm. We then output the topics, represented as probability distributions over words.
    @@ -419,42 +414,35 @@ public class JavaLDAExample {
    +## Streaming k-means -In order to run the above application, follow the instructions -provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) -section of the Spark -Quick Start guide. Be sure to also include *spark-mllib* to your build file as -a dependency. - -## Streaming clustering - -When data arrive in a stream, we may want to estimate clusters dynamically, -updating them as new data arrive. MLlib provides support for streaming k-means clustering, -with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm -uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign +When data arrive in a stream, we may want to estimate clusters dynamically, +updating them as new data arrive. MLlib provides support for streaming k-means clustering, +with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm +uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign all points to their nearest cluster, compute new cluster centers, then update each cluster using: `\begin{equation} c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t} \end{equation}` `\begin{equation} - n_{t+1} = n_t + m_t + n_{t+1} = n_t + m_t \end{equation}` -Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned -to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$` -is the number of points added to the cluster in the current batch. The decay factor `$\alpha$` -can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning; -with `$\alpha$=0` only the most recent data will be used. This is analogous to an -exponentially-weighted moving average. +Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned +to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$` +is the number of points added to the cluster in the current batch. The decay factor `$\alpha$` +can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning; +with `$\alpha$=0` only the most recent data will be used. This is analogous to an +exponentially-weighted moving average. -The decay can be specified using a `halfLife` parameter, which determines the +The decay can be specified using a `halfLife` parameter, which determines the correct decay factor `a` such that, for data acquired at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5. The unit of time can be specified either as `batches` or `points` and the update rule will be adjusted accordingly. -### Examples +**Examples** This example shows how to estimate clusters on streaming data. @@ -472,9 +460,9 @@ import org.apache.spark.mllib.clustering.StreamingKMeans {% endhighlight %} -Then we make an input stream of vectors for training, as well as a stream of labeled data -points for testing. We assume a StreamingContext `ssc` has been created, see -[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. +Then we make an input stream of vectors for training, as well as a stream of labeled data +points for testing. We assume a StreamingContext `ssc` has been created, see +[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. {% highlight scala %} @@ -496,24 +484,24 @@ val model = new StreamingKMeans() {% endhighlight %} -Now register the streams for training and testing and start the job, printing +Now register the streams for training and testing and start the job, printing the predicted cluster assignments on new data points as they arrive. {% highlight scala %} model.trainOn(trainingData) -model.predictOnValues(testData).print() +model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() - + {% endhighlight %} -As you add new text files with data the cluster centers will update. Each training +As you add new text files with data the cluster centers will update. Each training point should be formatted as `[x1, x2, x3]`, and each test data point -should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier -(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` +should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier +(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. With new data, the cluster centers will change!
    diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 3d32d03e35c62..fbe809b3478e5 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -21,12 +21,15 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv * [naive Bayes](mllib-naive-bayes.html) * [decision trees](mllib-decision-tree.html) * [ensembles of trees](mllib-ensembles.html) (Random Forests and Gradient-Boosted Trees) + * [isotonic regression](mllib-isotonic-regression.html) * [Collaborative filtering](mllib-collaborative-filtering.html) * alternating least squares (ALS) * [Clustering](mllib-clustering.html) - * k-means - * Gaussian mixture - * power iteration + * [k-means](mllib-clustering.html#k-means) + * [Gaussian mixture](mllib-clustering.html#gaussian-mixture) + * [power iteration clustering (PIC)](mllib-clustering.html#power-iteration-clustering-pic) + * [latent Dirichlet allocation (LDA)](mllib-clustering.html#latent-dirichlet-allocation-lda) + * [streaming k-means](mllib-clustering.html#streaming-k-means) * [Dimensionality reduction](mllib-dimensionality-reduction.html) * singular value decomposition (SVD) * principal component analysis (PCA) diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md new file mode 100644 index 0000000000000..12fb29d426741 --- /dev/null +++ b/docs/mllib-isotonic-regression.md @@ -0,0 +1,155 @@ +--- +layout: global +title: Naive Bayes - MLlib +displayTitle: MLlib - Regression +--- + +## Isotonic regression +[Isotonic regression](http://en.wikipedia.org/wiki/Isotonic_regression) +belongs to the family of regression algorithms. Formally isotonic regression is a problem where +given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses +and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted +finding a function that minimises + +`\begin{equation} + f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 +\end{equation}` + +with respect to complete order subject to +`$x_1\le x_2\le ...\le x_n$` where `$w_i$` are positive weights. +The resulting function is called isotonic regression and it is unique. +It can be viewed as least squares problem under order restriction. +Essentially isotonic regression is a +[monotonic function](http://en.wikipedia.org/wiki/Monotonic_function) +best fitting the original data points. + +MLlib supports a +[pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) +which uses an approach to +[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). +The training input is a RDD of tuples of three double values that represent +label, feature and weight in this order. Additionally IsotonicRegression algorithm has one +optional parameter called $isotonic$ defaulting to true. +This argument specifies if the isotonic regression is +isotonic (monotonically increasing) or antitonic (monotonically decreasing). + +Training returns an IsotonicRegressionModel that can be used to predict +labels for both known and unknown features. The result of isotonic regression +is treated as piecewise linear function. The rules for prediction therefore are: + +* If the prediction input exactly matches a training feature + then associated prediction is returned. In case there are multiple predictions with the same + feature then one of them is returned. Which one is undefined + (same as java.util.Arrays.binarySearch). +* If the prediction input is lower or higher than all training features + then prediction with lowest or highest feature is returned respectively. + In case there are multiple predictions with the same feature + then the lowest or highest is returned respectively. +* If the prediction input falls between two training features then prediction is treated + as piecewise linear function and interpolated value is calculated from the + predictions of the two closest features. In case there are multiple values + with the same feature then the same rules as in previous point are used. + +### Examples + +
    +
    +Data are read from a file where each line has a format label,feature +i.e. 4710.28,500.00. The data are split to training and testing set. +Model is created using the training set and a mean squared error is calculated from the predicted +labels and real labels in the test set. + +{% highlight scala %} +import org.apache.spark.mllib.regression.IsotonicRegression + +val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") + +// Create label, feature, weight tuples from input data with weight set to default value 1.0. +val parsedData = data.map { line => + val parts = line.split(',').map(_.toDouble) + (parts(0), parts(1), 1.0) +} + +// Split data into training (60%) and test (40%) sets. +val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) +val training = splits(0) +val test = splits(1) + +// Create isotonic regression model from training data. +// Isotonic parameter defaults to true so it is only shown for demonstration +val model = new IsotonicRegression().setIsotonic(true).run(training) + +// Create tuples of predicted and real labels. +val predictionAndLabel = test.map { point => + val predictedLabel = model.predict(point._2) + (predictedLabel, point._1) +} + +// Calculate mean squared error between predicted and real labels. +val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean() +println("Mean Squared Error = " + meanSquaredError) +{% endhighlight %} +
    + +
    +Data are read from a file where each line has a format label,feature +i.e. 4710.28,500.00. The data are split to training and testing set. +Model is created using the training set and a mean squared error is calculated from the predicted +labels and real labels in the test set. + +{% highlight java %} +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.IsotonicRegressionModel; +import scala.Tuple2; +import scala.Tuple3; + +JavaRDD data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt"); + +// Create label, feature, weight tuples from input data with weight set to default value 1.0. +JavaRDD> parsedData = data.map( + new Function>() { + public Tuple3 call(String line) { + String[] parts = line.split(","); + return new Tuple3<>(new Double(parts[0]), new Double(parts[1]), 1.0); + } + } +); + +// Split data into training (60%) and test (40%) sets. +JavaRDD>[] splits = parsedData.randomSplit(new double[] {0.6, 0.4}, 11L); +JavaRDD> training = splits[0]; +JavaRDD> test = splits[1]; + +// Create isotonic regression model from training data. +// Isotonic parameter defaults to true so it is only shown for demonstration +IsotonicRegressionModel model = new IsotonicRegression().setIsotonic(true).run(training); + +// Create tuples of predicted and real labels. +JavaPairRDD predictionAndLabel = test.mapToPair( + new PairFunction, Double, Double>() { + @Override public Tuple2 call(Tuple3 point) { + Double predictedLabel = model.predict(point._2()); + return new Tuple2(predictedLabel, point._1()); + } + } +); + +// Calculate mean squared error between predicted and real labels. +Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( + new Function, Object>() { + @Override public Object call(Tuple2 pl) { + return Math.pow(pl._1() - pl._2(), 2); + } + } +).rdd()).mean(); + +System.out.println("Mean Squared Error = " + meanSquaredError); +{% endhighlight %} +
    +
    \ No newline at end of file diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b2b007509c735..8022c5ecc2430 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -14,10 +14,10 @@ title: Spark SQL Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using Spark. At the core of this component is a new type of RDD, -[SchemaRDD](api/scala/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed of +[DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame). DataFrames are composed of [Row](api/scala/index.html#org.apache.spark.sql.package@Row:org.apache.spark.sql.catalyst.expressions.Row.type) objects, along with -a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table -in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) +a schema that describes the data types of each column in the row. A DataFrame is similar to a table +in a traditional relational database. A DataFrame can be created from an existing RDD, a [Parquet](http://parquet.io) file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`. @@ -27,10 +27,10 @@ All of the examples on this page use sample data included in the Spark distribut
    Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using Spark. At the core of this component is a new type of RDD, -[JavaSchemaRDD](api/scala/index.html#org.apache.spark.sql.api.java.JavaSchemaRDD). JavaSchemaRDDs are composed of +[DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame). DataFrames are composed of [Row](api/scala/index.html#org.apache.spark.sql.api.java.Row) objects, along with -a schema that describes the data types of each column in the row. A JavaSchemaRDD is similar to a table -in a traditional relational database. A JavaSchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) +a schema that describes the data types of each column in the row. A DataFrame is similar to a table +in a traditional relational database. A DataFrame can be created from an existing RDD, a [Parquet](http://parquet.io) file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/).
    @@ -38,10 +38,10 @@ file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive]( Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using Spark. At the core of this component is a new type of RDD, -[SchemaRDD](api/python/pyspark.sql.SchemaRDD-class.html). SchemaRDDs are composed of +[DataFrame](api/python/pyspark.sql.html#pyspark.sql.DataFrame). DataFrames are composed of [Row](api/python/pyspark.sql.Row-class.html) objects, along with -a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table -in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) +a schema that describes the data types of each column in the row. A DataFrame is similar to a table +in a traditional relational database. A DataFrame can be created from an existing RDD, a [Parquet](http://parquet.io) file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). All of the examples on this page use sample data included in the Spark distribution and can be run in the `pyspark` shell. @@ -65,8 +65,8 @@ descendants. To create a basic SQLContext, all you need is a SparkContext. val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. -import sqlContext.createSchemaRDD +// this is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ {% endhighlight %} In addition to the basic SQLContext, you can also create a HiveContext, which provides a @@ -84,12 +84,12 @@ feature parity with a HiveContext.
    The entry point into all relational functionality in Spark is the -[JavaSQLContext](api/scala/index.html#org.apache.spark.sql.api.java.JavaSQLContext) class, or one -of its descendants. To create a basic JavaSQLContext, all you need is a JavaSparkContext. +[SQLContext](api/scala/index.html#org.apache.spark.sql.api.SQLContext) class, or one +of its descendants. To create a basic SQLContext, all you need is a JavaSparkContext. {% highlight java %} JavaSparkContext sc = ...; // An existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); {% endhighlight %} In addition to the basic SQLContext, you can also create a HiveContext, which provides a strict @@ -138,21 +138,21 @@ default is "hiveql", though "sql" is also available. Since the HiveQL parser is # Data Sources -Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface. -A SchemaRDD can be operated on as normal RDDs and can also be registered as a temporary table. -Registering a SchemaRDD as a table allows you to run SQL queries over its data. This section -describes the various methods for loading data into a SchemaRDD. +Spark SQL supports operating on a variety of data sources through the `DataFrame` interface. +A DataFrame can be operated on as normal RDDs and can also be registered as a temporary table. +Registering a DataFrame as a table allows you to run SQL queries over its data. This section +describes the various methods for loading data into a DataFrame. ## RDDs -Spark SQL supports two different methods for converting existing RDDs into SchemaRDDs. The first +Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first method uses reflection to infer the schema of an RDD that contains specific types of objects. This reflection based approach leads to more concise code and works well when you already know the schema while writing your Spark application. -The second method for creating SchemaRDDs is through a programmatic interface that allows you to +The second method for creating DataFrames is through a programmatic interface that allows you to construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows -you to construct SchemaRDDs when the columns and their types are not known until runtime. +you to construct DataFrames when the columns and their types are not known until runtime. ### Inferring the Schema Using Reflection
    @@ -160,17 +160,17 @@ you to construct SchemaRDDs when the columns and their types are not known until
    The Scala interaface for Spark SQL supports automatically converting an RDD containing case classes -to a SchemaRDD. The case class +to a DataFrame. The case class defines the schema of the table. The names of the arguments to the case class are read using reflection and become the names of the columns. Case classes can also be nested or contain complex -types such as Sequences or Arrays. This RDD can be implicitly converted to a SchemaRDD and then be +types such as Sequences or Arrays. This RDD can be implicitly converted to a DataFrame and then be registered as a table. Tables can be used in subsequent SQL statements. {% highlight scala %} // sc is an existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. -import sqlContext.createSchemaRDD +// this is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ // Define the schema using a case class. // Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, @@ -184,7 +184,7 @@ people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. teenagers.map(t => "Name: " + t(0)).collect().foreach(println) {% endhighlight %} @@ -194,7 +194,7 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
    Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) -into a Schema RDD. The BeanInfo, obtained using reflection, defines the schema of the table. +into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a class that implements Serializable and has getters and setters for all of its fields. @@ -230,7 +230,7 @@ for the JavaBean. {% highlight java %} // sc is an existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // Load a text file and convert each line to a JavaBean. JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").map( @@ -247,13 +247,13 @@ JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").m }); // Apply a schema to an RDD of JavaBeans and register it as a table. -JavaSchemaRDD schemaPeople = sqlContext.createDataFrame(people, Person.class); +DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. -JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. List teenagerNames = teenagers.map(new Function() { public String call(Row row) { @@ -267,7 +267,7 @@ List teenagerNames = teenagers.map(new Function() {
    -Spark SQL can convert an RDD of Row objects to a SchemaRDD, inferring the datatypes. Rows are constructed by passing a list of +Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table, and the types are inferred by looking at the first row. Since we currently only look at the first row, it is important that there is no missing data in the first row of the RDD. In future versions we @@ -284,11 +284,11 @@ lines = sc.textFile("examples/src/main/resources/people.txt") parts = lines.map(lambda l: l.split(",")) people = parts.map(lambda p: Row(name=p[0], age=int(p[1]))) -# Infer the schema, and register the SchemaRDD as a table. +# Infer the schema, and register the DataFrame as a table. schemaPeople = sqlContext.inferSchema(people) schemaPeople.registerTempTable("people") -# SQL can be run over SchemaRDDs that have been registered as a table. +# SQL can be run over DataFrames that have been registered as a table. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") # The results of SQL queries are RDDs and support all the normal RDD operations. @@ -310,7 +310,7 @@ for teenName in teenNames.collect(): When case classes cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), -a `SchemaRDD` can be created programmatically with three steps. +a `DataFrame` can be created programmatically with three steps. 1. Create an RDD of `Row`s from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of @@ -341,15 +341,15 @@ val schema = val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim)) // Apply the schema to the RDD. -val peopleSchemaRDD = sqlContext.createDataFrame(rowRDD, schema) +val peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema) -// Register the SchemaRDD as a table. -peopleSchemaRDD.registerTempTable("people") +// Register the DataFrames as a table. +peopleDataFrame.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val results = sqlContext.sql("SELECT name FROM people") -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. results.map(t => "Name: " + t(0)).collect().foreach(println) {% endhighlight %} @@ -362,13 +362,13 @@ results.map(t => "Name: " + t(0)).collect().foreach(println) When JavaBean classes cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), -a `SchemaRDD` can be created programmatically with three steps. +a `DataFrame` can be created programmatically with three steps. 1. Create an RDD of `Row`s from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of `Row`s in the RDD created in Step 1. 3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided -by `JavaSQLContext`. +by `SQLContext`. For example: {% highlight java %} @@ -381,7 +381,7 @@ import org.apache.spark.sql.api.java.StructField import org.apache.spark.sql.api.java.Row // sc is an existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // Load a text file and convert each line to a JavaBean. JavaRDD people = sc.textFile("examples/src/main/resources/people.txt"); @@ -406,15 +406,15 @@ JavaRDD rowRDD = people.map( }); // Apply the schema to the RDD. -JavaSchemaRDD peopleSchemaRDD = sqlContext.createDataFrame(rowRDD, schema); +DataFrame peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema); -// Register the SchemaRDD as a table. -peopleSchemaRDD.registerTempTable("people"); +// Register the DataFrame as a table. +peopleDataFrame.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. -JavaSchemaRDD results = sqlContext.sql("SELECT name FROM people"); +DataFrame results = sqlContext.sql("SELECT name FROM people"); -// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. List names = results.map(new Function() { public String call(Row row) { @@ -431,7 +431,7 @@ List names = results.map(new Function() { When a dictionary of kwargs cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), -a `SchemaRDD` can be created programmatically with three steps. +a `DataFrame` can be created programmatically with three steps. 1. Create an RDD of tuples or lists from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of @@ -460,10 +460,10 @@ schema = StructType(fields) # Apply the schema to the RDD. schemaPeople = sqlContext.createDataFrame(people, schema) -# Register the SchemaRDD as a table. +# Register the DataFrame as a table. schemaPeople.registerTempTable("people") -# SQL can be run over SchemaRDDs that have been registered as a table. +# SQL can be run over DataFrames that have been registered as a table. results = sqlContext.sql("SELECT name FROM people") # The results of SQL queries are RDDs and support all the normal RDD operations. @@ -493,16 +493,16 @@ Using the data from the above example: {% highlight scala %} // sqlContext from the previous example is used in this example. -// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD. -import sqlContext.createSchemaRDD +// This is used to implicitly convert an RDD to a DataFrame. +import sqlContext.implicits._ val people: RDD[Person] = ... // An RDD of case class objects, from the previous example. -// The RDD is implicitly converted to a SchemaRDD by createSchemaRDD, allowing it to be stored using Parquet. +// The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. people.saveAsParquetFile("people.parquet") // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. -// The result of loading a Parquet file is also a SchemaRDD. +// The result of loading a Parquet file is also a DataFrame. val parquetFile = sqlContext.parquetFile("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. @@ -518,18 +518,18 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println) {% highlight java %} // sqlContext from the previous example is used in this example. -JavaSchemaRDD schemaPeople = ... // The JavaSchemaRDD from the previous example. +DataFrame schemaPeople = ... // The DataFrame from the previous example. -// JavaSchemaRDDs can be saved as Parquet files, maintaining the schema information. +// DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.saveAsParquetFile("people.parquet"); // Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. -// The result of loading a parquet file is also a JavaSchemaRDD. -JavaSchemaRDD parquetFile = sqlContext.parquetFile("people.parquet"); +// The result of loading a parquet file is also a DataFrame. +DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); -JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); +DataFrame teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); List teenagerNames = teenagers.map(new Function() { public String call(Row row) { return "Name: " + row.getString(0); @@ -544,13 +544,13 @@ List teenagerNames = teenagers.map(new Function() { {% highlight python %} # sqlContext from the previous example is used in this example. -schemaPeople # The SchemaRDD from the previous example. +schemaPeople # The DataFrame from the previous example. -# SchemaRDDs can be saved as Parquet files, maintaining the schema information. +# DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.saveAsParquetFile("people.parquet") # Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. -# The result of loading a parquet file is also a SchemaRDD. +# The result of loading a parquet file is also a DataFrame. parquetFile = sqlContext.parquetFile("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. @@ -629,7 +629,7 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or
    -Spark SQL can automatically infer the schema of a JSON dataset and load it as a SchemaRDD. +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. This conversion can be done using one of two methods in a SQLContext: * `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. @@ -646,7 +646,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. val path = "examples/src/main/resources/people.json" -// Create a SchemaRDD from the file(s) pointed to by path +// Create a DataFrame from the file(s) pointed to by path val people = sqlContext.jsonFile(path) // The inferred schema can be visualized using the printSchema() method. @@ -655,13 +655,13 @@ people.printSchema() // |-- age: integer (nullable = true) // |-- name: string (nullable = true) -// Register this SchemaRDD as a table. +// Register this DataFrame as a table. people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -// Alternatively, a SchemaRDD can be created for a JSON dataset represented by +// Alternatively, a DataFrame can be created for a JSON dataset represented by // an RDD[String] storing one JSON object per string. val anotherPeopleRDD = sc.parallelize( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) @@ -671,8 +671,8 @@ val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD)
    -Spark SQL can automatically infer the schema of a JSON dataset and load it as a JavaSchemaRDD. -This conversion can be done using one of two methods in a JavaSQLContext : +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. +This conversion can be done using one of two methods in a SQLContext : * `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. * `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. @@ -683,13 +683,13 @@ a regular multi-line JSON file will most often fail. {% highlight java %} // sc is an existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; -// Create a JavaSchemaRDD from the file(s) pointed to by path -JavaSchemaRDD people = sqlContext.jsonFile(path); +// Create a DataFrame from the file(s) pointed to by path +DataFrame people = sqlContext.jsonFile(path); // The inferred schema can be visualized using the printSchema() method. people.printSchema(); @@ -697,23 +697,23 @@ people.printSchema(); // |-- age: integer (nullable = true) // |-- name: string (nullable = true) -// Register this JavaSchemaRDD as a table. +// Register this DataFrame as a table. people.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlContext. -JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); +DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); -// Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by +// Alternatively, a DataFrame can be created for a JSON dataset represented by // an RDD[String] storing one JSON object per string. List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = sc.parallelize(jsonData); -JavaSchemaRDD anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); +DataFrame anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); {% endhighlight %}
    -Spark SQL can automatically infer the schema of a JSON dataset and load it as a SchemaRDD. +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. This conversion can be done using one of two methods in a SQLContext: * `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. @@ -731,7 +731,7 @@ sqlContext = SQLContext(sc) # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. path = "examples/src/main/resources/people.json" -# Create a SchemaRDD from the file(s) pointed to by path +# Create a DataFrame from the file(s) pointed to by path people = sqlContext.jsonFile(path) # The inferred schema can be visualized using the printSchema() method. @@ -740,13 +740,13 @@ people.printSchema() # |-- age: integer (nullable = true) # |-- name: string (nullable = true) -# Register this SchemaRDD as a table. +# Register this DataFrame as a table. people.registerTempTable("people") # SQL statements can be run by using the sql methods provided by sqlContext. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") -# Alternatively, a SchemaRDD can be created for a JSON dataset represented by +# Alternatively, a DataFrame can be created for a JSON dataset represented by # an RDD[String] storing one JSON object per string. anotherPeopleRDD = sc.parallelize([ '{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}']) @@ -792,14 +792,14 @@ sqlContext.sql("FROM src SELECT key, value").collect().foreach(println)
    -When working with Hive one must construct a `JavaHiveContext`, which inherits from `JavaSQLContext`, and +When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to -the `sql` method a `JavaHiveContext` also provides an `hql` methods, which allows queries to be +the `sql` method a `HiveContext` also provides an `hql` methods, which allows queries to be expressed in HiveQL. {% highlight java %} // sc is an existing JavaSparkContext. -JavaHiveContext sqlContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc); +HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc); sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); @@ -841,7 +841,7 @@ turning on some experimental options. ## Caching Data In Memory -Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")` or `schemaRDD.cache()`. +Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")` or `dataFrame.cache()`. Then Spark SQL will scan only required columns and will automatically tune compression to minimize memory usage and GC pressure. You can call `sqlContext.uncacheTable("tableName")` to remove the table from memory. @@ -1161,7 +1161,7 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println) The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are evaluated by the SQL execution engine. A full list of the functions supported can be found in the -[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). +[ScalaDoc](api/scala/index.html#org.apache.spark.sql.DataFrame). diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 96fb12ce5e0b9..997de9511ca3e 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -878,6 +878,12 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi val runningCounts = pairs.updateStateByKey[Int](updateFunction _) {% endhighlight %} +The update function will be called for each word, with `newValues` having a sequence of 1's (from +the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete +Scala code, take a look at the example +[StatefulNetworkWordCount.scala]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache +/spark/examples/streaming/StatefulNetworkWordCount.scala). +
    @@ -899,6 +905,13 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi JavaPairDStream runningCounts = pairs.updateStateByKey(updateFunction); {% endhighlight %} +The update function will be called for each word, with `newValues` having a sequence of 1's (from +the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete +Java code, take a look at the example +[JavaStatefulNetworkWordCount.java]({{site +.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming +/JavaStatefulNetworkWordCount.java). +
    @@ -916,14 +929,14 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi runningCounts = pairs.updateStateByKey(updateFunction) {% endhighlight %} -
    -
    - The update function will be called for each word, with `newValues` having a sequence of 1's (from the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete -Scala code, take a look at the example +Python code, take a look at the example [stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py). +
    +
    + Note that using `updateStateByKey` requires the checkpoint directory to be configured, which is discussed in detail in the [checkpointing](#checkpointing) section. diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 0ea7365d75b83..c59ab565c6862 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -42,7 +42,7 @@ from optparse import OptionParser from sys import stderr -SPARK_EC2_VERSION = "1.2.0" +SPARK_EC2_VERSION = "1.2.1" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) VALID_SPARK_VERSIONS = set([ @@ -58,6 +58,7 @@ "1.1.0", "1.1.1", "1.2.0", + "1.2.1", ]) DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION @@ -1173,11 +1174,12 @@ def real_main(): time.sleep(30) # Yes, it does have to be this long :-( for group in groups: try: - conn.delete_security_group(group.name) - print "Deleted security group " + group.name + # It is needed to use group_id to make it work with VPC + conn.delete_security_group(group_id=group.id) + print "Deleted security group %s" % group.name except boto.exception.EC2ResponseError: success = False - print "Failed to delete security group " + group.name + print "Failed to delete security group %s" % group.name # Unfortunately, group.revoke() returns True even if a rule was not # deleted, so this needs to be rerun if something fails diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java new file mode 100644 index 0000000000000..d46c7107c7a21 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming; + +import java.util.Arrays; +import java.util.List; +import java.util.regex.Pattern; + +import scala.Tuple2; + +import com.google.common.base.Optional; +import com.google.common.collect.Lists; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.StorageLevels; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +/** + * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every + * second starting with initial value of word count. + * Usage: JavaStatefulNetworkWordCount + * and describe the TCP server that Spark Streaming would connect to receive + * data. + *

    + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example + * org.apache.spark.examples.streaming.JavaStatefulNetworkWordCount localhost 9999` + */ +public class JavaStatefulNetworkWordCount { + private static final Pattern SPACE = Pattern.compile(" "); + + public static void main(String[] args) { + if (args.length < 2) { + System.err.println("Usage: JavaStatefulNetworkWordCount "); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + // Update the cumulative count function + final Function2, Optional, Optional> updateFunction = + new Function2, Optional, Optional>() { + @Override + public Optional call(List values, Optional state) { + Integer newSum = state.or(0); + for (Integer value : values) { + newSum += value; + } + return Optional.of(newSum); + } + }; + + // Create the context with a 1 second batch size + SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount"); + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); + ssc.checkpoint("."); + + // Initial RDD input to updateStateByKey + List> tuples = Arrays.asList(new Tuple2("hello", 1), + new Tuple2("world", 1)); + JavaPairRDD initialRDD = ssc.sc().parallelizePairs(tuples); + + JavaReceiverInputDStream lines = ssc.socketTextStream( + args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER_2); + + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(SPACE.split(x)); + } + }); + + JavaPairDStream wordsDstream = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }); + + // This will give a Dstream made of state (which is the cumulative count of the words) + JavaPairDStream stateDstream = wordsDstream.updateStateByKey(updateFunction, + new HashPartitioner(ssc.sc().defaultParallelism()), initialRDD); + + stateDstream.print(); + ssc.start(); + ssc.awaitTermination(); + } +} diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py index c7df3d7b74767..b4d9355b681f6 100644 --- a/examples/src/main/python/ml/simple_text_classification_pipeline.py +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -36,43 +36,33 @@ sqlCtx = SQLContext(sc) # Prepare training documents, which are labeled. - LabeledDocument = Row('id', 'text', 'label') - training = sqlCtx.inferSchema( - sc.parallelize([(0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0)]) - .map(lambda x: LabeledDocument(*x))) + LabeledDocument = Row("id", "text", "label") + training = sc.parallelize([(0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)]) \ + .map(lambda x: LabeledDocument(*x)).toDF() # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. - tokenizer = Tokenizer() \ - .setInputCol("text") \ - .setOutputCol("words") - hashingTF = HashingTF() \ - .setInputCol(tokenizer.getOutputCol()) \ - .setOutputCol("features") - lr = LogisticRegression() \ - .setMaxIter(10) \ - .setRegParam(0.01) - pipeline = Pipeline() \ - .setStages([tokenizer, hashingTF, lr]) + tokenizer = Tokenizer(inputCol="text", outputCol="words") + hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") + lr = LogisticRegression(maxIter=10, regParam=0.01) + pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) # Fit the pipeline to training documents. model = pipeline.fit(training) # Prepare test documents, which are unlabeled. - Document = Row('id', 'text') - test = sqlCtx.inferSchema( - sc.parallelize([(4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop")]) - .map(lambda x: Document(*x))) + Document = Row("id", "text") + test = sc.parallelize([(4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")]) \ + .map(lambda x: Document(*x)).toDF() # Make predictions on test documents and print columns of interest. prediction = model.transform(test) - prediction.registerTempTable("prediction") - selected = sqlCtx.sql("SELECT id, text, prediction from prediction") + selected = prediction.select("id", "text", "prediction") for row in selected.collect(): print row diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index a2893f78e0fec..f0241943ef410 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -90,7 +90,7 @@ object CrossValidatorExample { crossval.setNumFolds(2) // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. - val cvModel = crossval.fit(training) + val cvModel = crossval.fit(training.toDF) // Prepare test documents, which are unlabeled. val test = sc.parallelize(Seq( @@ -100,7 +100,7 @@ object CrossValidatorExample { Document(7L, "apache hadoop"))) // Make predictions on test documents. cvModel uses the best model found (lrModel). - cvModel.transform(test) + cvModel.transform(test.toDF) .select("id", "text", "probability", "prediction") .collect() .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index aed44238939c7..54aadd2288817 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -58,7 +58,7 @@ object DeveloperApiExample { lr.setMaxIter(10) // Learn a LogisticRegression model. This uses the parameters stored in lr. - val model = lr.fit(training) + val model = lr.fit(training.toDF) // Prepare test data. val test = sc.parallelize(Seq( @@ -67,7 +67,7 @@ object DeveloperApiExample { LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) // Make predictions on test data. - val sumPredictions: Double = model.transform(test) + val sumPredictions: Double = model.transform(test.toDF) .select("features", "label", "prediction") .collect() .map { case Row(features: Vector, label: Double, prediction: Double) => diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index 836ea2e01201e..adaf796dc1896 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -137,9 +137,9 @@ object MovieLensALS { .setRegParam(params.regParam) .setNumBlocks(params.numBlocks) - val model = als.fit(training) + val model = als.fit(training.toDF) - val predictions = model.transform(test).cache() + val predictions = model.transform(test.toDF).cache() // Evaluate the model. // TODO: Create an evaluator to compute RMSE. @@ -158,7 +158,7 @@ object MovieLensALS { // Inspect false positives. predictions.registerTempTable("prediction") - sc.textFile(params.movies).map(Movie.parseMovie).registerTempTable("movie") + sc.textFile(params.movies).map(Movie.parseMovie).toDF.registerTempTable("movie") sqlContext.sql( """ |SELECT userId, prediction.movieId, title, rating, prediction diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index 80c9f5ff5781e..c5bb5515b1930 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -58,7 +58,7 @@ object SimpleParamsExample { .setRegParam(0.01) // Learn a LogisticRegression model. This uses the parameters stored in lr. - val model1 = lr.fit(training) + val model1 = lr.fit(training.toDF) // Since model1 is a Model (i.e., a Transformer produced by an Estimator), // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this @@ -77,7 +77,7 @@ object SimpleParamsExample { // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. - val model2 = lr.fit(training, paramMapCombined) + val model2 = lr.fit(training.toDF, paramMapCombined) println("Model 2 was fit using parameters: " + model2.fittingParamMap) // Prepare test data. @@ -90,7 +90,7 @@ object SimpleParamsExample { // LogisticRegression.transform will only use the 'features' column. // Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. - model2.transform(test) + model2.transform(test.toDF) .select("features", "label", "myProbability", "prediction") .collect() .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index 968cb292120d8..8b47f88e48df1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -69,7 +69,7 @@ object SimpleTextClassificationPipeline { .setStages(Array(tokenizer, hashingTF, lr)) // Fit the pipeline to training documents. - val model = pipeline.fit(training) + val model = pipeline.fit(training.toDF) // Prepare test documents, which are unlabeled. val test = sc.parallelize(Seq( @@ -79,7 +79,7 @@ object SimpleTextClassificationPipeline { Document(7L, "apache hadoop"))) // Make predictions on test documents. - model.transform(test) + model.transform(test.toDF) .select("id", "text", "probability", "prediction") .collect() .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index 89b6255991a38..c98c68a02f2be 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -81,18 +81,18 @@ object DatasetExample { println(s"Loaded ${origData.count()} instances from file: ${params.input}") // Convert input data to DataFrame explicitly. - val df: DataFrame = origData.toDataFrame + val df: DataFrame = origData.toDF println(s"Inferred schema:\n${df.schema.prettyJson}") println(s"Converted to DataFrame with ${df.count()} records") - // Select columns, using implicit conversion to DataFrames. - val labelsDf: DataFrame = origData.select("label") + // Select columns + val labelsDf: DataFrame = df.select("label") val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } val numLabels = labels.count() val meanLabel = labels.fold(0.0)(_ + _) / numLabels println(s"Selected label column with average value $meanLabel") - val featuresDf: DataFrame = origData.select("features") + val featuresDf: DataFrame = df.select("features") val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala new file mode 100644 index 0000000000000..b2373adba1fd4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.log4j.{Level, Logger} +import scopt.OptionParser + +import org.apache.spark.mllib.clustering.PowerIterationClustering +import org.apache.spark.rdd.RDD +import org.apache.spark.{SparkConf, SparkContext} + +/** + * An example Power Iteration Clustering http://www.icml2010.org/papers/387.pdf app. + * Takes an input of K concentric circles and the number of points in the innermost circle. + * The output should be K clusters - each cluster containing precisely the points associated + * with each of the input circles. + * + * Run with + * {{{ + * ./bin/run-example mllib.PowerIterationClusteringExample [options] + * + * Where options include: + * k: Number of circles/clusters + * n: Number of sampled points on innermost circle.. There are proportionally more points + * within the outer/larger circles + * maxIterations: Number of Power Iterations + * outerRadius: radius of the outermost of the concentric circles + * }}} + * + * Here is a sample run and output: + * + * ./bin/run-example mllib.PowerIterationClusteringExample + * -k 3 --n 30 --maxIterations 15 + * + * Cluster assignments: 1 -> [0,1,2,3,4],2 -> [5,6,7,8,9,10,11,12,13,14], + * 0 -> [15,16,17,18,19,20,21,22,23,24,25,26,27,28,29] + * + * + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object PowerIterationClusteringExample { + + case class Params( + input: String = null, + k: Int = 3, + numPoints: Int = 5, + maxIterations: Int = 10, + outerRadius: Double = 3.0 + ) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("PIC Circles") { + head("PowerIterationClusteringExample: an example PIC app using concentric circles.") + opt[Int]('k', "k") + .text(s"number of circles (/clusters), default: ${defaultParams.k}") + .action((x, c) => c.copy(k = x)) + opt[Int]('n', "n") + .text(s"number of points in smallest circle, default: ${defaultParams.numPoints}") + .action((x, c) => c.copy(numPoints = x)) + opt[Int]("maxIterations") + .text(s"number of iterations, default: ${defaultParams.maxIterations}") + .action((x, c) => c.copy(maxIterations = x)) + opt[Int]('r', "r") + .text(s"radius of outermost circle, default: ${defaultParams.outerRadius}") + .action((x, c) => c.copy(numPoints = x)) + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf() + .setMaster("local") + .setAppName(s"PowerIterationClustering with $params") + val sc = new SparkContext(conf) + + Logger.getRootLogger.setLevel(Level.WARN) + + val circlesRdd = generateCirclesRdd(sc, params.k, params.numPoints, params.outerRadius) + val model = new PowerIterationClustering() + .setK(params.k) + .setMaxIterations(params.maxIterations) + .run(circlesRdd) + + val clusters = model.assignments.collect.groupBy(_._2).mapValues(_.map(_._1)) + val assignments = clusters.toList.sortBy { case (k, v) => v.length} + val assignmentsStr = assignments + .map { case (k, v) => + s"$k -> ${v.sorted.mkString("[", ",", "]")}" + }.mkString(",") + val sizesStr = assignments.map { + _._2.size + }.sorted.mkString("(", ",", ")") + println(s"Cluster assignments: $assignmentsStr\ncluster sizes: $sizesStr") + + sc.stop() + } + + def generateCircle(radius: Double, n: Int) = { + Seq.tabulate(n) { i => + val theta = 2.0 * math.Pi * i / n + (radius * math.cos(theta), radius * math.sin(theta)) + } + } + + def generateCirclesRdd(sc: SparkContext, + nCircles: Int = 3, + nPoints: Int = 30, + outerRadius: Double): RDD[(Long, Long, Double)] = { + + val radii = Array.tabulate(nCircles) { cx => outerRadius / (nCircles - cx)} + val groupSizes = Array.tabulate(nCircles) { cx => (cx + 1) * nPoints} + val points = (0 until nCircles).flatMap { cx => + generateCircle(radii(cx), groupSizes(cx)) + }.zipWithIndex + val rdd = sc.parallelize(points) + val distancesRdd = rdd.cartesian(rdd).flatMap { case (((x0, y0), i0), ((x1, y1), i1)) => + if (i0 < i1) { + Some((i0.toLong, i1.toLong, gaussianSimilarity((x0, y0), (x1, y1), 1.0))) + } else { + None + } + } + distancesRdd + } + + /** + * Gaussian Similarity: http://en.wikipedia.org/wiki/Radial_basis_function_kernel + */ + def gaussianSimilarity(p1: (Double, Double), p2: (Double, Double), sigma: Double) = { + val coeff = 1.0 / (math.sqrt(2.0 * math.Pi) * sigma) + val expCoeff = -1.0 / 2.0 * math.pow(sigma, 2.0) + val ssquares = (p1._1 - p2._1) * (p1._1 - p2._1) + (p1._2 - p2._2) * (p1._2 - p2._2) + coeff * math.exp(expCoeff * ssquares) + // math.exp((p1._1 - p2._1) * (p1._1 - p2._1) + (p1._2 - p2._2) * (p1._2 - p2._2)) + } + + +} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 1eac3c8d03e39..79d3d5a24ceaf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -19,7 +19,7 @@ package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ // One method for defining the schema of an RDD is to make a case class with the desired column // names and types. @@ -34,10 +34,10 @@ object RDDRelation { // Importing the SQL context gives access to all the SQL functions and implicit conversions. import sqlContext.implicits._ - val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) + val df = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))).toDF // Any RDD containing case classes can be registered as a table. The schema of the table is // automatically inferred using scala reflection. - rdd.registerTempTable("records") + df.registerTempTable("records") // Once tables have been registered, you can run SQL queries over them. println("Result of SELECT *:") @@ -55,10 +55,10 @@ object RDDRelation { rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) // Queries can also be written using a LINQ-like Scala DSL. - rdd.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) + df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) // Write out an RDD as a parquet file. - rdd.saveAsParquetFile("pair.parquet") + df.saveAsParquetFile("pair.parquet") // Read in parquet file. Parquet files are self-describing so the schmema is preserved. val parquetFile = sqlContext.parquetFile("pair.parquet") diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 15754cdfcc35e..7128deba54da7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -68,7 +68,7 @@ object HiveFromSpark { // You can also register RDDs as temporary tables within a HiveContext. val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) - rdd.registerTempTable("records") + rdd.toDF.registerTempTable("records") // Queries can then join RDD data with data stored in Hive. println("Result of SELECT *:") diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index b25c2120d54f7..926094449e7fc 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -67,7 +67,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase } - test("basic stream receiving with multiple topics and smallest starting offset") { + ignore("basic stream receiving with multiple topics and smallest starting offset") { val topics = Set("basic1", "basic2", "basic3") val data = Map("a" -> 7, "b" -> 9) topics.foreach { t => @@ -113,7 +113,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase ssc.stop() } - test("receiving from largest starting offset") { + ignore("receiving from largest starting offset") { val topic = "largest" val topicPartition = TopicAndPartition(topic, 0) val data = Map("a" -> 10) @@ -158,7 +158,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase } - test("creating stream by offset") { + ignore("creating stream by offset") { val topic = "offset" val topicPartition = TopicAndPartition(topic, 0) val data = Map("a" -> 10) @@ -204,7 +204,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase } // Test to verify the offset ranges can be recovered from the checkpoints - test("offset recovery") { + ignore("offset recovery") { val topic = "recovery" createTopic(topic) testDir = Utils.createTempDir() diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index c815eda52bda7..216661b8bc73a 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -67,11 +67,6 @@ scalacheck_${scala.binary.version} test - - org.easymock - easymockclassextension - test - com.novocode junit-interface diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 41dbd64c2b1fa..f56898af029c1 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -20,7 +20,6 @@ import java.nio.ByteBuffer import scala.collection.JavaConversions.seqAsJavaList -import org.apache.spark.annotation.Experimental import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.Milliseconds import org.apache.spark.streaming.Seconds @@ -28,9 +27,11 @@ import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.TestSuiteBase import org.apache.spark.streaming.util.Clock import org.apache.spark.streaming.util.ManualClock + +import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.scalatest.Matchers -import org.scalatest.mock.EasyMockSugar +import org.scalatest.mock.MockitoSugar import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException @@ -42,10 +43,10 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record /** - * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor + * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor */ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter - with EasyMockSugar { + with MockitoSugar { val app = "TestKinesisReceiver" val stream = "mySparkStream" @@ -73,6 +74,14 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft currentClockMock = mock[Clock] } + override def afterFunction(): Unit = { + super.afterFunction() + // Since this suite was originally written using EasyMock, add this to preserve the old + // mocking semantics (see SPARK-5735 for more details) + verifyNoMoreInteractions(receiverMock, checkpointerMock, checkpointClockMock, + checkpointStateMock, currentClockMock) + } + test("kinesis utils api") { val ssc = new StreamingContext(master, framework, batchDuration) // Tests the API, does not actually test data receiving @@ -83,193 +92,175 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft } test("process records including store and checkpoint") { - val expectedCheckpointIntervalMillis = 10 - expecting { - receiverMock.isStopped().andReturn(false).once() - receiverMock.store(record1.getData().array()).once() - receiverMock.store(record2.getData().array()).once() - checkpointStateMock.shouldCheckpoint().andReturn(true).once() - checkpointerMock.checkpoint().once() - checkpointStateMock.advanceCheckpoint().once() - } - whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.processRecords(batch, checkpointerMock) - } + when(receiverMock.isStopped()).thenReturn(false) + when(checkpointStateMock.shouldCheckpoint()).thenReturn(true) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + + verify(receiverMock, times(1)).isStopped() + verify(receiverMock, times(1)).store(record1.getData().array()) + verify(receiverMock, times(1)).store(record2.getData().array()) + verify(checkpointStateMock, times(1)).shouldCheckpoint() + verify(checkpointerMock, times(1)).checkpoint() + verify(checkpointStateMock, times(1)).advanceCheckpoint() } test("shouldn't store and checkpoint when receiver is stopped") { - expecting { - receiverMock.isStopped().andReturn(true).once() - } - whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.processRecords(batch, checkpointerMock) - } + when(receiverMock.isStopped()).thenReturn(true) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + + verify(receiverMock, times(1)).isStopped() } test("shouldn't checkpoint when exception occurs during store") { - expecting { - receiverMock.isStopped().andReturn(false).once() - receiverMock.store(record1.getData().array()).andThrow(new RuntimeException()).once() - } - whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { - intercept[RuntimeException] { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.processRecords(batch, checkpointerMock) - } + when(receiverMock.isStopped()).thenReturn(false) + when(receiverMock.store(record1.getData().array())).thenThrow(new RuntimeException()) + + intercept[RuntimeException] { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) } + + verify(receiverMock, times(1)).isStopped() + verify(receiverMock, times(1)).store(record1.getData().array()) } test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { + when(currentClockMock.currentTime()).thenReturn(0) + val checkpointIntervalMillis = 10 - val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + val checkpointState = + new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) - } + + verify(currentClockMock, times(1)).currentTime() } test("should checkpoint if we have exceeded the checkpoint interval") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) - assert(checkpointState.shouldCheckpoint()) - } + when(currentClockMock.currentTime()).thenReturn(0) + + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) + assert(checkpointState.shouldCheckpoint()) + + verify(currentClockMock, times(1)).currentTime() } test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) - assert(!checkpointState.shouldCheckpoint()) - } + when(currentClockMock.currentTime()).thenReturn(0) + + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) + assert(!checkpointState.shouldCheckpoint()) + + verify(currentClockMock, times(1)).currentTime() } test("should add to time when advancing checkpoint") { - expecting { - currentClockMock.currentTime().andReturn(0).once() - } - whenExecuting(currentClockMock) { - val checkpointIntervalMillis = 10 - val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) - checkpointState.advanceCheckpoint() - assert(checkpointState.checkpointClock.currentTime() == (2 * checkpointIntervalMillis)) - } + when(currentClockMock.currentTime()).thenReturn(0) + + val checkpointIntervalMillis = 10 + val checkpointState = + new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) + checkpointState.advanceCheckpoint() + assert(checkpointState.checkpointClock.currentTime() == (2 * checkpointIntervalMillis)) + + verify(currentClockMock, times(1)).currentTime() } test("shutdown should checkpoint if the reason is TERMINATE") { - expecting { - checkpointerMock.checkpoint().once() - } - whenExecuting(checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - val reason = ShutdownReason.TERMINATE - recordProcessor.shutdown(checkpointerMock, reason) - } + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val reason = ShutdownReason.TERMINATE + recordProcessor.shutdown(checkpointerMock, reason) + + verify(checkpointerMock, times(1)).checkpoint() } test("shutdown should not checkpoint if the reason is something other than TERMINATE") { - expecting { - } - whenExecuting(checkpointerMock, checkpointStateMock) { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, - checkpointStateMock) - recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) - recordProcessor.shutdown(checkpointerMock, null) - } + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) + recordProcessor.shutdown(checkpointerMock, null) + + verify(checkpointerMock, never()).checkpoint() } test("retry success on first attempt") { val expectedIsStopped = false - expecting { - receiverMock.isStopped().andReturn(expectedIsStopped).once() - } - whenExecuting(receiverMock) { - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - } + when(receiverMock.isStopped()).thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(1)).isStopped() } test("retry success on second attempt after a Kinesis throttling exception") { val expectedIsStopped = false - expecting { - receiverMock.isStopped().andThrow(new ThrottlingException("error message")) - .andReturn(expectedIsStopped).once() - } - whenExecuting(receiverMock) { - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - } + when(receiverMock.isStopped()) + .thenThrow(new ThrottlingException("error message")) + .thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(2)).isStopped() } test("retry success on second attempt after a Kinesis dependency exception") { val expectedIsStopped = false - expecting { - receiverMock.isStopped().andThrow(new KinesisClientLibDependencyException("error message")) - .andReturn(expectedIsStopped).once() - } - whenExecuting(receiverMock) { - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - } + when(receiverMock.isStopped()) + .thenThrow(new KinesisClientLibDependencyException("error message")) + .thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(2)).isStopped() } test("retry failed after a shutdown exception") { - expecting { - checkpointerMock.checkpoint().andThrow(new ShutdownException("error message")).once() - } - whenExecuting(checkpointerMock) { - intercept[ShutdownException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } + when(checkpointerMock.checkpoint()).thenThrow(new ShutdownException("error message")) + + intercept[ShutdownException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + + verify(checkpointerMock, times(1)).checkpoint() } test("retry failed after an invalid state exception") { - expecting { - checkpointerMock.checkpoint().andThrow(new InvalidStateException("error message")).once() - } - whenExecuting(checkpointerMock) { - intercept[InvalidStateException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } + when(checkpointerMock.checkpoint()).thenThrow(new InvalidStateException("error message")) + + intercept[InvalidStateException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + + verify(checkpointerMock, times(1)).checkpoint() } test("retry failed after unexpected exception") { - expecting { - checkpointerMock.checkpoint().andThrow(new RuntimeException("error message")).once() - } - whenExecuting(checkpointerMock) { - intercept[RuntimeException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } + when(checkpointerMock.checkpoint()).thenThrow(new RuntimeException("error message")) + + intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + + verify(checkpointerMock, times(1)).checkpoint() } test("retry failed after exhausing all retries") { val expectedErrorMessage = "final try error message" - expecting { - checkpointerMock.checkpoint().andThrow(new ThrottlingException("error message")) - .andThrow(new ThrottlingException(expectedErrorMessage)).once() - } - whenExecuting(checkpointerMock) { - val exception = intercept[RuntimeException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } - exception.getMessage().shouldBe(expectedErrorMessage) + when(checkpointerMock.checkpoint()) + .thenThrow(new ThrottlingException("error message")) + .thenThrow(new ThrottlingException(expectedErrorMessage)) + + val exception = intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) } + exception.getMessage().shouldBe(expectedErrorMessage) + + verify(checkpointerMock, times(2)).checkpoint() } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index f58587e10a820..fc84cfbe64184 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -17,6 +17,8 @@ package org.apache.spark.graphx.lib +import org.apache.spark.annotation.Experimental + import scala.util.Random import org.jblas.DoubleMatrix import org.apache.spark.rdd._ @@ -38,6 +40,8 @@ object SVDPlusPlus { extends Serializable /** + * :: Experimental :: + * * Implement SVD++ based on "Factorization Meets the Neighborhood: * a Multifaceted Collaborative Filtering Model", * available at [[http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf]]. @@ -45,12 +49,33 @@ object SVDPlusPlus { * The prediction rule is rui = u + bu + bi + qi*(pu + |N(u)|^(-0.5)*sum(y)), * see the details on page 6. * + * This method temporarily replaces `run()`, and replaces `DoubleMatrix` in `run()`'s return + * value with `Array[Double]`. In 1.4.0, this method will be deprecated, but will be copied + * to replace `run()`, which will then be undeprecated. + * * @param edges edges for constructing the graph * * @param conf SVDPlusPlus parameters * * @return a graph with vertex attributes containing the trained model */ + @Experimental + def runSVDPlusPlus(edges: RDD[Edge[Double]], conf: Conf) + : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) = + { + val (graph, u) = run(edges, conf) + // Convert DoubleMatrix to Array[Double]: + val newVertices = graph.vertices.mapValues(v => (v._1.toArray, v._2.toArray, v._3, v._4)) + (Graph(newVertices, graph.edges), u) + } + + /** + * This method is deprecated in favor of `runSVDPlusPlus()`, which replaces `DoubleMatrix` + * with `Array[Double]` in its return value. This method is deprecated. It will effectively + * be removed in 1.4.0 when `runSVDPlusPlus()` is copied to replace `run()`, and hence the + * return type of this method changes. + */ + @deprecated("Call runSVDPlusPlus", "1.3.0") def run(edges: RDD[Edge[Double]], conf: Conf) : (Graph[(DoubleMatrix, DoubleMatrix, Double, Double), Double], Double) = { @@ -72,17 +97,22 @@ object SVDPlusPlus { // construct graph var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache() + materialize(g) + edges.unpersist() // Calculate initial bias and norm val t0 = g.aggregateMessages[(Long, Double)]( ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) }, (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2)) - g = g.outerJoinVertices(t0) { + val gJoinT0 = g.outerJoinVertices(t0) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[(Long, Double)]) => (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) - } + }.cache() + materialize(gJoinT0) + g.unpersist() + g = gJoinT0 def sendMsgTrainF(conf: Conf, u: Double) (ctx: EdgeContext[ @@ -114,12 +144,15 @@ object SVDPlusPlus { val t1 = g.aggregateMessages[DoubleMatrix]( ctx => ctx.sendToSrc(ctx.dstAttr._2), (g1, g2) => g1.addColumnVector(g2)) - g = g.outerJoinVertices(t1) { + val gJoinT1 = g.outerJoinVertices(t1) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[DoubleMatrix]) => if (msg.isDefined) (vd._1, vd._1 .addColumnVector(msg.get.mul(vd._4)), vd._3, vd._4) else vd - } + }.cache() + materialize(gJoinT1) + g.unpersist() + g = gJoinT1 // Phase 2, update p for user nodes and q, y for item nodes g.cache() @@ -127,13 +160,16 @@ object SVDPlusPlus { sendMsgTrainF(conf, u), (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) => (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3)) - g = g.outerJoinVertices(t2) { + val gJoinT2 = g.outerJoinVertices(t2) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[(DoubleMatrix, DoubleMatrix, Double)]) => (vd._1.addColumnVector(msg.get._1), vd._2.addColumnVector(msg.get._2), vd._3 + msg.get._3, vd._4) - } + }.cache() + materialize(gJoinT2) + g.unpersist() + g = gJoinT2 } // calculate error on training set @@ -147,13 +183,26 @@ object SVDPlusPlus { val err = (ctx.attr - pred) * (ctx.attr - pred) ctx.sendToDst(err) } + g.cache() val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _) - g = g.outerJoinVertices(t3) { + val gJoinT3 = g.outerJoinVertices(t3) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) => if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd - } + }.cache() + materialize(gJoinT3) + g.unpersist() + g = gJoinT3 (g, u) } + + /** + * Forces materialization of a Graph by count()ing its RDDs. + */ + private def materialize(g: Graph[_,_]): Unit = { + g.vertices.count() + g.edges.count() + } + } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala index e01df56e94de9..9987a4b1a3c25 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala @@ -32,7 +32,7 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble) } val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations - var (graph, u) = SVDPlusPlus.run(edges, conf) + var (graph, u) = SVDPlusPlus.runSVDPlusPlus(edges, conf) graph.cache() val err = graph.vertices.collect().map{ case (vid, vd) => if (vid % 2 == 1) vd._4 else 0.0 diff --git a/make-distribution.sh b/make-distribution.sh index 051c87c0894ae..dd990d4b96e46 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -98,9 +98,9 @@ done if [ -z "$JAVA_HOME" ]; then # Fall back on JAVA_HOME from rpm, if found if [ $(command -v rpm) ]; then - RPM_JAVA_HOME=$(rpm -E %java_home 2>/dev/null) + RPM_JAVA_HOME="$(rpm -E %java_home 2>/dev/null)" if [ "$RPM_JAVA_HOME" != "%java_home" ]; then - JAVA_HOME=$RPM_JAVA_HOME + JAVA_HOME="$RPM_JAVA_HOME" echo "No JAVA_HOME set, proceeding with '$JAVA_HOME' learned from rpm" fi fi @@ -113,24 +113,24 @@ fi if [ $(command -v git) ]; then GITREV=$(git rev-parse --short HEAD 2>/dev/null || :) - if [ ! -z $GITREV ]; then + if [ ! -z "$GITREV" ]; then GITREVSTRING=" (git revision $GITREV)" fi unset GITREV fi -if [ ! $(command -v $MVN) ] ; then +if [ ! $(command -v "$MVN") ] ; then echo -e "Could not locate Maven command: '$MVN'." echo -e "Specify the Maven command with the --mvn flag" exit -1; fi -VERSION=$($MVN help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) -SPARK_HADOOP_VERSION=$($MVN help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ +VERSION=$("$MVN" help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) +SPARK_HADOOP_VERSION=$("$MVN" help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ | grep -v "INFO"\ | tail -n 1) -SPARK_HIVE=$($MVN help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ +SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ | grep -v "INFO"\ | fgrep --count "hive";\ # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\ @@ -147,7 +147,7 @@ if [[ ! "$JAVA_VERSION" =~ "1.6" && -z "$SKIP_JAVA_TEST" ]]; then echo "Output from 'java -version' was:" echo "$JAVA_VERSION" read -p "Would you like to continue anyways? [y,n]: " -r - if [[ ! $REPLY =~ ^[Yy]$ ]]; then + if [[ ! "$REPLY" =~ ^[Yy]$ ]]; then echo "Okay, exiting." exit 1 fi @@ -232,7 +232,7 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR" if [ "$SPARK_TACHYON" == "true" ]; then TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` - pushd $TMPD > /dev/null + pushd "$TMPD" > /dev/null echo "Fetching tachyon tgz" TACHYON_DL="${TACHYON_TGZ}.part" @@ -259,7 +259,7 @@ if [ "$SPARK_TACHYON" == "true" ]; then fi popd > /dev/null - rm -rf $TMPD + rm -rf "$TMPD" fi if [ "$MAKE_TGZ" == "true" ]; then diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index bb291e6e1fd7d..5607ed21afe18 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -114,7 +114,9 @@ class Pipeline extends Estimator[PipelineModel] { throw new IllegalArgumentException( s"Do not support stage $stage of type ${stage.getClass}") } - curDataset = transformer.transform(curDataset, paramMap) + if (index < indexOfLastEstimator) { + curDataset = transformer.transform(curDataset, paramMap) + } transformers += transformer } else { transformers += stage.asInstanceOf[Transformer] diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index cd95c16aa768d..9a5848684b179 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -23,7 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** @@ -62,7 +62,10 @@ abstract class Transformer extends PipelineStage with Params { private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] extends Transformer with HasInputCol with HasOutputCol with Logging { + /** @group setParam */ def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T] + + /** @group setParam */ def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T] /** @@ -97,7 +100,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) val map = this.paramMap ++ paramMap - dataset.select($"*", callUDF( - this.createTransformFunc(map), outputDataType, dataset(map(inputCol))).as(map(outputCol))) + dataset.withColumn(map(outputCol), + callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol)))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 1bf8eb4640d11..c5fc89f935432 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -66,6 +66,7 @@ private[spark] abstract class Classifier[ extends Predictor[FeaturesType, E, M] with ClassifierParams { + /** @group setParam */ def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] @@ -87,6 +88,7 @@ private[spark] abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]] extends PredictionModel[FeaturesType, M] with ClassifierParams { + /** @group setParam */ def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M] /** Number of classes (values which the label can take). */ @@ -180,24 +182,22 @@ private[ml] object ClassificationModel { if (map(model.rawPredictionCol) != "") { // output raw prediction val features2raw: FeaturesType => Vector = model.predictRaw - tmpData = tmpData.select($"*", - callUDF(features2raw, new VectorUDT, - col(map(model.featuresCol))).as(map(model.rawPredictionCol))) + tmpData = tmpData.withColumn(map(model.rawPredictionCol), + callUDF(features2raw, new VectorUDT, col(map(model.featuresCol)))) numColsOutput += 1 if (map(model.predictionCol) != "") { val raw2pred: Vector => Double = (rawPred) => { rawPred.toArray.zipWithIndex.maxBy(_._1)._2 } - tmpData = tmpData.select($"*", callUDF(raw2pred, DoubleType, - col(map(model.rawPredictionCol))).as(map(model.predictionCol))) + tmpData = tmpData.withColumn(map(model.predictionCol), + callUDF(raw2pred, DoubleType, col(map(model.rawPredictionCol)))) numColsOutput += 1 } } else if (map(model.predictionCol) != "") { // output prediction val features2pred: FeaturesType => Double = model.predict - tmpData = tmpData.select($"*", - callUDF(features2pred, DoubleType, - col(map(model.featuresCol))).as(map(model.predictionCol))) + tmpData = tmpData.withColumn(map(model.predictionCol), + callUDF(features2pred, DoubleType, col(map(model.featuresCol)))) numColsOutput += 1 } (numColsOutput, tmpData) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index c146fe244c66e..21f61d80dd95a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors} import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel @@ -49,8 +49,13 @@ class LogisticRegression setMaxIter(100) setThreshold(0.5) + /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) + + /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) override protected def train(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = { @@ -93,6 +98,7 @@ class LogisticRegressionModel private[ml] ( setThreshold(0.5) + /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) private val margin: Vector => Double = (features) => { @@ -124,44 +130,39 @@ class LogisticRegressionModel private[ml] ( var numColsOutput = 0 if (map(rawPredictionCol) != "") { val features2raw: Vector => Vector = (features) => predictRaw(features) - tmpData = tmpData.select($"*", - callUDF(features2raw, new VectorUDT, col(map(featuresCol))).as(map(rawPredictionCol))) + tmpData = tmpData.withColumn(map(rawPredictionCol), + callUDF(features2raw, new VectorUDT, col(map(featuresCol)))) numColsOutput += 1 } if (map(probabilityCol) != "") { if (map(rawPredictionCol) != "") { - val raw2prob: Vector => Vector = { (rawPreds: Vector) => + val raw2prob = udf { (rawPreds: Vector) => val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) - Vectors.dense(1.0 - prob1, prob1) + Vectors.dense(1.0 - prob1, prob1): Vector } - tmpData = tmpData.select($"*", - callUDF(raw2prob, new VectorUDT, col(map(rawPredictionCol))).as(map(probabilityCol))) + tmpData = tmpData.withColumn(map(probabilityCol), raw2prob(col(map(rawPredictionCol)))) } else { - val features2prob: Vector => Vector = (features: Vector) => predictProbabilities(features) - tmpData = tmpData.select($"*", - callUDF(features2prob, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol))) + val features2prob = udf { (features: Vector) => predictProbabilities(features) : Vector } + tmpData = tmpData.withColumn(map(probabilityCol), features2prob(col(map(featuresCol)))) } numColsOutput += 1 } if (map(predictionCol) != "") { val t = map(threshold) if (map(probabilityCol) != "") { - val predict: Vector => Double = { probs: Vector => + val predict = udf { probs: Vector => if (probs(1) > t) 1.0 else 0.0 } - tmpData = tmpData.select($"*", - callUDF(predict, DoubleType, col(map(probabilityCol))).as(map(predictionCol))) + tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(probabilityCol)))) } else if (map(rawPredictionCol) != "") { - val predict: Vector => Double = { rawPreds: Vector => + val predict = udf { rawPreds: Vector => val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) if (prob1 > t) 1.0 else 0.0 } - tmpData = tmpData.select($"*", - callUDF(predict, DoubleType, col(map(rawPredictionCol))).as(map(predictionCol))) + tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(rawPredictionCol)))) } else { - val predict: Vector => Double = (features: Vector) => this.predict(features) - tmpData = tmpData.select($"*", - callUDF(predict, DoubleType, col(map(featuresCol))).as(map(predictionCol))) + val predict = udf { features: Vector => this.predict(features) } + tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(featuresCol)))) } numColsOutput += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 1202528ca654e..bd8caac855981 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} @@ -61,6 +61,7 @@ private[spark] abstract class ProbabilisticClassifier[ M <: ProbabilisticClassificationModel[FeaturesType, M]] extends Classifier[FeaturesType, E, M] with ProbabilisticClassifierParams { + /** @group setParam */ def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E] } @@ -82,6 +83,7 @@ private[spark] abstract class ProbabilisticClassificationModel[ M <: ProbabilisticClassificationModel[FeaturesType, M]] extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams { + /** @group setParam */ def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M] /** @@ -120,8 +122,8 @@ private[spark] abstract class ProbabilisticClassificationModel[ val features2probs: FeaturesType => Vector = (features) => { tmpModel.predictProbabilities(features) } - outputData.select($"*", - callUDF(features2probs, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol))) + outputData.withColumn(map(probabilityCol), + callUDF(features2probs, new VectorUDT, col(map(featuresCol)))) } else { if (numColsOutput == 0) { this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index f21a30627e540..2360f4479f1c2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -35,13 +35,23 @@ import org.apache.spark.sql.types.DoubleType class BinaryClassificationEvaluator extends Evaluator with Params with HasRawPredictionCol with HasLabelCol { - /** param for metric name in evaluation */ + /** + * param for metric name in evaluation + * @group param + */ val metricName: Param[String] = new Param(this, "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC")) + + /** @group getParam */ def getMetricName: String = get(metricName) + + /** @group setParam */ def setMetricName(value: String): this.type = set(metricName, value) + /** @group setParam */ def setScoreCol(value: String): this.type = set(rawPredictionCol, value) + + /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 0956062643f23..6131ba8832691 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -31,11 +31,18 @@ import org.apache.spark.sql.types.DataType @AlphaComponent class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { - /** number of features */ + /** + * number of features + * @group param + */ val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18)) - def setNumFeatures(value: Int) = set(numFeatures, value) + + /** @group getParam */ def getNumFeatures: Int = get(numFeatures) + /** @group setParam */ + def setNumFeatures(value: Int) = set(numFeatures, value) + override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { val hashingTF = new feature.HashingTF(paramMap(numFeatures)) hashingTF.transform diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 4745a7ae95679..ddbd648d64f23 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} /** @@ -39,7 +39,10 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with @AlphaComponent class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams { + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { @@ -75,14 +78,17 @@ class StandardScalerModel private[ml] ( scaler: feature.StandardScalerModel) extends Model[StandardScalerModel] with StandardScalerParams { + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) val map = this.paramMap ++ paramMap val scale = udf((v: Vector) => { scaler.transform(v) } : Vector) - dataset.select($"*", scale(col(map(inputCol))).as(map(outputCol))) + dataset.withColumn(map(outputCol), scale(col(map(inputCol)))) } private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala index 89b53f3890ea3..7daeff980f0ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala @@ -24,7 +24,7 @@ import org.apache.spark.mllib.linalg.{VectorUDT, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -85,8 +85,13 @@ private[spark] abstract class Predictor[ M <: PredictionModel[FeaturesType, M]] extends Estimator[M] with PredictorParams { + /** @group setParam */ def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner] + + /** @group setParam */ def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner] + + /** @group setParam */ def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner] override def fit(dataset: DataFrame, paramMap: ParamMap): M = { @@ -160,8 +165,10 @@ private[spark] abstract class Predictor[ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]] extends Model[M] with PredictorParams { + /** @group setParam */ def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M] + /** @group setParam */ def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M] /** @@ -209,7 +216,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel val pred: FeaturesType => Double = (features) => { tmpModel.predict(features) } - dataset.select($"*", callUDF(pred, DoubleType, col(map(featuresCol))).as(map(predictionCol))) + dataset.withColumn(map(predictionCol), callUDF(pred, DoubleType, col(map(featuresCol)))) } else { this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + " since no output columns were set.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala index 51cd48c90432a..b45bd1499b72e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -20,5 +20,19 @@ package org.apache.spark /** * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly * assemble and configure practical machine learning pipelines. + * + * @groupname param Parameters + * @groupdesc param A list of (hyper-)parameter keys this algorithm can take. Users can set and get + * the parameter values through setters and getters, respectively. + * @groupprio param -5 + * + * @groupname setParam Parameter setters + * @groupprio setParam 5 + * + * @groupname getParam Parameter getters + * @groupprio getParam 6 + * + * @groupname Ungrouped Members + * @groupprio Ungrouped 0 */ package object ml diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala index 32fc74462ef4a..1a70322b4cace 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -24,67 +24,117 @@ package org.apache.spark.ml.param */ private[ml] trait HasRegParam extends Params { - /** param for regularization parameter */ + /** + * param for regularization parameter + * @group param + */ val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") + + /** @group getParam */ def getRegParam: Double = get(regParam) } private[ml] trait HasMaxIter extends Params { - /** param for max number of iterations */ + /** + * param for max number of iterations + * @group param + */ val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + + /** @group getParam */ def getMaxIter: Int = get(maxIter) } private[ml] trait HasFeaturesCol extends Params { - /** param for features column name */ + /** + * param for features column name + * @group param + */ val featuresCol: Param[String] = new Param(this, "featuresCol", "features column name", Some("features")) + + /** @group getParam */ def getFeaturesCol: String = get(featuresCol) } private[ml] trait HasLabelCol extends Params { - /** param for label column name */ + /** + * param for label column name + * @group param + */ val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label")) + + /** @group getParam */ def getLabelCol: String = get(labelCol) } private[ml] trait HasPredictionCol extends Params { - /** param for prediction column name */ + /** + * param for prediction column name + * @group param + */ val predictionCol: Param[String] = new Param(this, "predictionCol", "prediction column name", Some("prediction")) + + /** @group getParam */ def getPredictionCol: String = get(predictionCol) } private[ml] trait HasRawPredictionCol extends Params { - /** param for raw prediction column name */ + /** + * param for raw prediction column name + * @group param + */ val rawPredictionCol: Param[String] = new Param(this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name", Some("rawPrediction")) + + /** @group getParam */ def getRawPredictionCol: String = get(rawPredictionCol) } private[ml] trait HasProbabilityCol extends Params { - /** param for predicted class conditional probabilities column name */ + /** + * param for predicted class conditional probabilities column name + * @group param + */ val probabilityCol: Param[String] = new Param(this, "probabilityCol", "column name for predicted class conditional probabilities", Some("probability")) + + /** @group getParam */ def getProbabilityCol: String = get(probabilityCol) } private[ml] trait HasThreshold extends Params { - /** param for threshold in (binary) prediction */ + /** + * param for threshold in (binary) prediction + * @group param + */ val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") + + /** @group getParam */ def getThreshold: Double = get(threshold) } private[ml] trait HasInputCol extends Params { - /** param for input column name */ + /** + * param for input column name + * @group param + */ val inputCol: Param[String] = new Param(this, "inputCol", "input column name") + + /** @group getParam */ def getInputCol: String = get(inputCol) } private[ml] trait HasOutputCol extends Params { - /** param for output column name */ + /** + * param for output column name + * @group param + */ val outputCol: Param[String] = new Param(this, "outputCol", "output column name") + + /** @group getParam */ def getOutputCol: String = get(outputCol) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index bf5737177ceee..8d70e4347c4c9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -36,7 +36,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -49,43 +49,89 @@ import org.apache.spark.util.random.XORShiftRandom private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam with HasPredictionCol { - /** Param for rank of the matrix factorization. */ + /** + * Param for rank of the matrix factorization. + * @group param + */ val rank = new IntParam(this, "rank", "rank of the factorization", Some(10)) + + /** @group getParam */ def getRank: Int = get(rank) - /** Param for number of user blocks. */ + /** + * Param for number of user blocks. + * @group param + */ val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10)) + + /** @group getParam */ def getNumUserBlocks: Int = get(numUserBlocks) - /** Param for number of item blocks. */ + /** + * Param for number of item blocks. + * @group param + */ val numItemBlocks = new IntParam(this, "numItemBlocks", "number of item blocks", Some(10)) + + /** @group getParam */ def getNumItemBlocks: Int = get(numItemBlocks) - /** Param to decide whether to use implicit preference. */ + /** + * Param to decide whether to use implicit preference. + * @group param + */ val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false)) + + /** @group getParam */ def getImplicitPrefs: Boolean = get(implicitPrefs) - /** Param for the alpha parameter in the implicit preference formulation. */ + /** + * Param for the alpha parameter in the implicit preference formulation. + * @group param + */ val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0)) + + /** @group getParam */ def getAlpha: Double = get(alpha) - /** Param for the column name for user ids. */ + /** + * Param for the column name for user ids. + * @group param + */ val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user")) + + /** @group getParam */ def getUserCol: String = get(userCol) - /** Param for the column name for item ids. */ + /** + * Param for the column name for item ids. + * @group param + */ val itemCol = new Param[String](this, "itemCol", "column name for item ids", Some("item")) + + /** @group getParam */ def getItemCol: String = get(itemCol) - /** Param for the column name for ratings. */ + /** + * Param for the column name for ratings. + * @group param + */ val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating")) + + /** @group getParam */ def getRatingCol: String = get(ratingCol) + /** + * Param for whether to apply nonnegativity constraints. + * @group param + */ val nonnegative = new BooleanParam( this, "nonnegative", "whether to use nonnegative constraint for least squares", Some(false)) + + /** @group getParam */ val getNonnegative: Boolean = get(nonnegative) /** @@ -124,8 +170,8 @@ class ALSModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { import dataset.sqlContext.implicits._ val map = this.paramMap ++ paramMap - val users = userFactors.toDataFrame("id", "features") - val items = itemFactors.toDataFrame("id", "features") + val users = userFactors.toDF("id", "features") + val items = itemFactors.toDF("id", "features") // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. @@ -181,20 +227,46 @@ class ALS extends Estimator[ALSModel] with ALSParams { import org.apache.spark.ml.recommendation.ALS.Rating + /** @group setParam */ def setRank(value: Int): this.type = set(rank, value) + + /** @group setParam */ def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value) + + /** @group setParam */ def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value) + + /** @group setParam */ def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value) + + /** @group setParam */ def setAlpha(value: Double): this.type = set(alpha, value) + + /** @group setParam */ def setUserCol(value: String): this.type = set(userCol, value) + + /** @group setParam */ def setItemCol(value: String): this.type = set(itemCol, value) + + /** @group setParam */ def setRatingCol(value: String): this.type = set(ratingCol, value) + + /** @group setParam */ def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) + + /** @group setParam */ def setNonnegative(value: Boolean): this.type = set(nonnegative, value) - /** Sets both numUserBlocks and numItemBlocks to the specific value. */ + /** + * Sets both numUserBlocks and numItemBlocks to the specific value. + * @group setParam + */ def setNumBlocks(value: Int): this.type = { setNumUserBlocks(value) setNumItemBlocks(value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index d5a7bdafcb623..65f6627a0c351 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -44,7 +44,10 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress setRegParam(0.1) setMaxIter(100) + /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) + + /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 324b1ba784387..b07a68269cc2b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -31,22 +31,42 @@ import org.apache.spark.sql.types.StructType * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ private[ml] trait CrossValidatorParams extends Params { - /** param for the estimator to be cross-validated */ + /** + * param for the estimator to be cross-validated + * @group param + */ val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") + + /** @group getParam */ def getEstimator: Estimator[_] = get(estimator) - /** param for estimator param maps */ + /** + * param for estimator param maps + * @group param + */ val estimatorParamMaps: Param[Array[ParamMap]] = new Param(this, "estimatorParamMaps", "param maps for the estimator") + + /** @group getParam */ def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps) - /** param for the evaluator for selection */ + /** + * param for the evaluator for selection + * @group param + */ val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") + + /** @group getParam */ def getEvaluator: Evaluator = get(evaluator) - /** param for number of folds for cross validation */ + /** + * param for number of folds for cross validation + * @group param + */ val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation", Some(3)) + + /** @group getParam */ def getNumFolds: Int = get(numFolds) } @@ -59,9 +79,16 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP private val f2jBLAS = new F2jBLAS + /** @group setParam */ def setEstimator(value: Estimator[_]): this.type = set(estimator, value) + + /** @group setParam */ def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) + + /** @group setParam */ def setEvaluator(value: Evaluator): this.type = set(evaluator, value) + + /** @group setParam */ def setNumFolds(value: Int): this.type = set(numFolds, value) override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { @@ -81,6 +108,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP // multi-model training logDebug(s"Train split $splitIndex with multiple sets of parameters.") val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] + trainingDataset.unpersist() var i = 0 while (i < numModels) { val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map) @@ -88,6 +116,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP metrics(i) += metric i += 1 } + validationDataset.unpersist() } f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1) logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index 348c1e8760a66..35a0db76f3a8c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -17,12 +17,12 @@ package org.apache.spark.mllib.classification +import org.json4s.{DefaultFormats, JValue} + import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.util.Loader import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} /** * :: Experimental :: @@ -60,16 +60,10 @@ private[mllib] object ClassificationModel { /** * Helper method for loading GLM classification model metadata. - * - * @param modelClass String name for model class (used for error messages) * @return (numFeatures, numClasses) */ - def getNumFeaturesClasses(metadata: DataFrame, modelClass: String, path: String): (Int, Int) = { - metadata.select("numFeatures", "numClasses").take(1)(0) match { - case Row(nFeatures: Int, nClasses: Int) => (nFeatures, nClasses) - case _ => throw new Exception(s"$modelClass unable to load" + - s" numFeatures, numClasses from metadata: ${Loader.metadataPath(path)}") - } + def getNumFeaturesClasses(metadata: JValue): (Int, Int) = { + implicit val formats = DefaultFormats + ((metadata \ "numFeatures").extract[Int], (metadata \ "numClasses").extract[Int]) } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 9a391bfff76a3..420d6e2861934 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -173,8 +173,7 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] { val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel" (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => - val (numFeatures, numClasses) = - ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path) + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) // numFeatures, numClasses, weights are checked in model initialization val model = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index d9ce2822dd391..dd7a9469d5041 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -18,15 +18,16 @@ package org.apache.spark.mllib.classification import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{SparkContext, SparkException, Logging} +import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} - /** * Model for Naive Bayes Classifiers. * @@ -78,7 +79,7 @@ class NaiveBayesModel private[mllib] ( object NaiveBayesModel extends Loader[NaiveBayesModel] { - import Loader._ + import org.apache.spark.mllib.util.Loader._ private object SaveLoadV1_0 { @@ -95,13 +96,13 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { import sqlContext.implicits._ // Create JSON metadata. - val metadataRDD = - sc.parallelize(Seq((thisClassName, thisFormatVersion, data.theta(0).size, data.pi.size)), 1) - .toDataFrame("class", "version", "numFeatures", "numClasses") - metadataRDD.toJSON.saveAsTextFile(metadataPath(path)) + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) // Create Parquet data. - val dataRDD: DataFrame = sc.parallelize(Seq(data), 1) + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF dataRDD.saveAsParquetFile(dataPath(path)) } @@ -126,8 +127,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { val classNameV1_0 = SaveLoadV1_0.thisClassName (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => - val (numFeatures, numClasses) = - ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path) + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) val model = SaveLoadV1_0.load(sc, path) assert(model.pi.size == numClasses, s"NaiveBayesModel.load expected $numClasses classes," + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 24d31e62ba500..cfc7f868a02f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -23,10 +23,9 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} +import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable} import org.apache.spark.rdd.RDD - /** * Model for Support Vector Machines (SVMs). * @@ -97,8 +96,7 @@ object SVMModel extends Loader[SVMModel] { val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel" (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => - val (numFeatures, numClasses) = - ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path) + val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata) val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0) val model = new SVMModel(data.weights, data.intercept) assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala index 8d600572ed7f3..0a358f2e4f39c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -17,10 +17,13 @@ package org.apache.spark.mllib.classification.impl +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.Loader -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{Row, SQLContext} /** * Helper class for import/export of GLM classification models. @@ -52,16 +55,14 @@ private[classification] object GLMClassificationModel { import sqlContext.implicits._ // Create JSON metadata. - val metadataRDD = - sc.parallelize(Seq((modelClass, thisFormatVersion, numFeatures, numClasses)), 1) - .toDataFrame("class", "version", "numFeatures", "numClasses") - metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path)) + val metadata = compact(render( + ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> numFeatures) ~ ("numClasses" -> numClasses))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) // Create Parquet data. val data = Data(weights, intercept, threshold) - val dataRDD: DataFrame = sc.parallelize(Seq(data), 1) - // TODO: repartition with 1 partition after SPARK-5532 gets fixed - dataRDD.saveAsParquetFile(Loader.dataPath(path)) + sc.parallelize(Seq(data), 1).toDF.saveAsParquetFile(Loader.dataPath(path)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index a3e40200bc063..59a79e5c6a4ac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable} import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ArrayBuilder import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -272,7 +272,7 @@ class Word2Vec extends Serializable with Logging { def hasNext: Boolean = iter.hasNext def next(): Array[Int] = { - var sentence = new ArrayBuffer[Int] + val sentence = ArrayBuilder.make[Int] var sentenceLength = 0 while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) { val word = bcVocabHash.value.get(iter.next()) @@ -283,7 +283,7 @@ class Word2Vec extends Serializable with Logging { case None => } } - sentence.toArray + sentence.result() } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 16979c9ed43ca..c399496568bfb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -22,6 +22,9 @@ import java.lang.{Integer => JavaInteger} import org.apache.hadoop.fs.Path import org.jblas.DoubleMatrix +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} @@ -153,7 +156,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { import org.apache.spark.mllib.util.Loader._ override def load(sc: SparkContext, path: String): MatrixFactorizationModel = { - val (loadedClassName, formatVersion, metadata) = loadMetadata(sc, path) + val (loadedClassName, formatVersion, _) = loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName (loadedClassName, formatVersion) match { case (className, "1.0") if className == classNameV1_0 => @@ -181,19 +184,20 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { val sc = model.userFeatures.sparkContext val sqlContext = new SQLContext(sc) import sqlContext.implicits._ - val metadata = (thisClassName, thisFormatVersion, model.rank) - val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank") - metadataRDD.toJSON.saveAsTextFile(metadataPath(path)) - model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path)) - model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path)) + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + model.userFeatures.toDF("id", "features").saveAsParquetFile(userPath(path)) + model.productFeatures.toDF("id", "features").saveAsParquetFile(productPath(path)) } def load(sc: SparkContext, path: String): MatrixFactorizationModel = { + implicit val formats = DefaultFormats val sqlContext = new SQLContext(sc) val (className, formatVersion, metadata) = loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) - val rank = metadata.select("rank").first().getInt(0) + val rank = (metadata \ "rank").extract[Int] val userFeatures = sqlContext.parquetFile(userPath(path)) .map { case Row(id: Int, features: Seq[Double]) => (id, features.toArray) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 1159e59fff5f6..e8b03816573cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -58,7 +58,7 @@ object LassoModel extends Loader[LassoModel] { val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel" (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => - val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path) + val numFeatures = RegressionModel.getNumFeatures(metadata) val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) new LassoModel(data.weights, data.intercept) case _ => throw new Exception( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 0136dcfdceaef..6fa7ad52a5b33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -58,7 +58,7 @@ object LinearRegressionModel extends Loader[LinearRegressionModel] { val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel" (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => - val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path) + val numFeatures = RegressionModel.getNumFeatures(metadata) val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) new LinearRegressionModel(data.weights, data.intercept) case _ => throw new Exception( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index 843e59bdfbdd2..214ac4d0ed7dd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -17,12 +17,12 @@ package org.apache.spark.mllib.regression +import org.json4s.{DefaultFormats, JValue} + import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.util.Loader import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} @Experimental trait RegressionModel extends Serializable { @@ -55,16 +55,10 @@ private[mllib] object RegressionModel { /** * Helper method for loading GLM regression model metadata. - * - * @param modelClass String name for model class (used for error messages) * @return numFeatures */ - def getNumFeatures(metadata: DataFrame, modelClass: String, path: String): Int = { - metadata.select("numFeatures").take(1)(0) match { - case Row(nFeatures: Int) => nFeatures - case _ => throw new Exception(s"$modelClass unable to load" + - s" numFeatures from metadata: ${Loader.metadataPath(path)}") - } + def getNumFeatures(metadata: JValue): Int = { + implicit val formats = DefaultFormats + (metadata \ "numFeatures").extract[Int] } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index f2a5f1db1ece6..8838ca8c14718 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -59,7 +59,7 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] { val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel" (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => - val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path) + val numFeatures = RegressionModel.getNumFeatures(metadata) val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures) new RidgeRegressionModel(data.weights, data.intercept) case _ => throw new Exception( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index 838100e949ec2..7b27aaa322b00 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -17,6 +17,9 @@ package org.apache.spark.mllib.regression.impl +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.Loader @@ -48,14 +51,14 @@ private[regression] object GLMRegressionModel { import sqlContext.implicits._ // Create JSON metadata. - val metadataRDD = - sc.parallelize(Seq((modelClass, thisFormatVersion, weights.size)), 1) - .toDataFrame("class", "version", "numFeatures") - metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path)) + val metadata = compact(render( + ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~ + ("numFeatures" -> weights.size))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) // Create Parquet data. val data = Data(weights, intercept) - val dataRDD: DataFrame = sc.parallelize(Seq(data), 1) + val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF // TODO: repartition with 1 partition after SPARK-5532 gets fixed dataRDD.saveAsParquetFile(Loader.dataPath(path)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b3e8ed9af8c51..f1f85994e61b2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - +import scala.collection.mutable.ArrayBuilder +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo import org.apache.spark.mllib.tree.configuration.Strategy @@ -32,13 +31,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.impl._ -import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.SparkContext._ - /** * :: Experimental :: @@ -1140,7 +1136,7 @@ object DecisionTree extends Serializable with Logging { logDebug("stride = " + stride) // iterate `valueCount` to find splits - val splits = new ArrayBuffer[Double] + val splitsBuilder = ArrayBuilder.make[Double] var index = 1 // currentCount: sum of counts of values that have been visited var currentCount = valueCounts(0)._2 @@ -1158,13 +1154,13 @@ object DecisionTree extends Serializable with Logging { // makes the gap between currentCount and targetCount smaller, // previous value is a split threshold. if (previousGap < currentGap) { - splits.append(valueCounts(index - 1)._1) + splitsBuilder += valueCounts(index - 1)._1 targetCount += stride } index += 1 } - splits.toArray + splitsBuilder.result() } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 89ecf3773dd77..5dac62b0c42f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -19,6 +19,10 @@ package org.apache.spark.mllib.tree.model import scala.collection.mutable +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD @@ -184,16 +188,16 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] { import sqlContext.implicits._ // Create JSON metadata. - val metadataRDD = sc.parallelize( - Seq((thisClassName, thisFormatVersion, model.algo.toString, model.numNodes)), 1) - .toDataFrame("class", "version", "algo", "numNodes") - metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path)) + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("algo" -> model.algo.toString) ~ ("numNodes" -> model.numNodes))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) // Create Parquet data. val nodes = model.topNode.subtreeIterator.toSeq val dataRDD: DataFrame = sc.parallelize(nodes) .map(NodeData.apply(0, _)) - .toDataFrame + .toDF dataRDD.saveAsParquetFile(Loader.dataPath(path)) } @@ -269,20 +273,10 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] { } override def load(sc: SparkContext, path: String): DecisionTreeModel = { + implicit val formats = DefaultFormats val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) - val (algo: String, numNodes: Int) = try { - val algo_numNodes = metadata.select("algo", "numNodes").collect() - assert(algo_numNodes.length == 1) - algo_numNodes(0) match { - case Row(a: String, n: Int) => (a, n) - } - } catch { - // Catch both Error and Exception since the checks above can throw either. - case e: Throwable => - throw new Exception( - s"Unable to load DecisionTreeModel metadata from: ${Loader.metadataPath(path)}." - + s" Error message: ${e.getMessage}") - } + val algo = (metadata \ "algo").extract[String] + val numNodes = (metadata \ "numNodes").extract[Int] val classNameV1_0 = SaveLoadV1_0.thisClassName (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 23bd46baabf65..e507f247cca76 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -20,18 +20,20 @@ package org.apache.spark.mllib.tree.model import scala.collection.mutable import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.Algo +import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ -import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, SQLContext} - +import org.apache.spark.sql.SQLContext /** * :: Experimental :: @@ -59,11 +61,11 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis object RandomForestModel extends Loader[RandomForestModel] { override def load(sc: SparkContext, path: String): RandomForestModel = { - val (loadedClassName, version, metadataRDD) = Loader.loadMetadata(sc, path) + val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => - val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(metadataRDD, path) + val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata) assert(metadata.treeWeights.forall(_ == 1.0)) val trees = TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo) @@ -110,11 +112,11 @@ class GradientBoostedTreesModel( object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = { - val (loadedClassName, version, metadataRDD) = Loader.loadMetadata(sc, path) + val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => - val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(metadataRDD, path) + val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata) assert(metadata.combiningStrategy == Sum.toString) val trees = TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo) @@ -252,7 +254,7 @@ private[tree] object TreeEnsembleModel { object SaveLoadV1_0 { - import DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees} + import org.apache.spark.mllib.tree.model.DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees} def thisFormatVersion = "1.0" @@ -276,38 +278,27 @@ private[tree] object TreeEnsembleModel { import sqlContext.implicits._ // Create JSON metadata. - val metadata = Metadata(model.algo.toString, model.trees(0).algo.toString, + implicit val format = DefaultFormats + val ensembleMetadata = Metadata(model.algo.toString, model.trees(0).algo.toString, model.combiningStrategy.toString, model.treeWeights) - val metadataRDD = sc.parallelize(Seq((className, thisFormatVersion, metadata)), 1) - .toDataFrame("class", "version", "metadata") - metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path)) + val metadata = compact(render( + ("class" -> className) ~ ("version" -> thisFormatVersion) ~ + ("metadata" -> Extraction.decompose(ensembleMetadata)))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) // Create Parquet data. val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) => tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node)) - }.toDataFrame + }.toDF dataRDD.saveAsParquetFile(Loader.dataPath(path)) } /** - * Read metadata from the loaded metadata DataFrame. - * @param path Path for loading data, used for debug messages. + * Read metadata from the loaded JSON metadata. */ - def readMetadata(metadata: DataFrame, path: String): Metadata = { - try { - // We rely on the try-catch for schema checking rather than creating a schema just for this. - val metadataArray = metadata.select("metadata.algo", "metadata.treeAlgo", - "metadata.combiningStrategy", "metadata.treeWeights").collect() - assert(metadataArray.size == 1) - Metadata(metadataArray(0).getString(0), metadataArray(0).getString(1), - metadataArray(0).getString(2), metadataArray(0).getAs[Seq[Double]](3).toArray) - } catch { - // Catch both Error and Exception since the checks above can throw either. - case e: Throwable => - throw new Exception( - s"Unable to load TreeEnsembleModel metadata from: ${Loader.metadataPath(path)}." - + s" Error message: ${e.getMessage}") - } + def readMetadata(metadata: JValue): Metadata = { + implicit val formats = DefaultFormats + (metadata \ "metadata").extract[Metadata] } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala index f7cba6c6cb628..308f7f3578e21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.util import java.util.StringTokenizer -import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.collection.mutable.{ArrayBuilder, ListBuffer} import org.apache.spark.SparkException @@ -51,7 +51,7 @@ private[mllib] object NumericParser { } private def parseArray(tokenizer: StringTokenizer): Array[Double] = { - val values = ArrayBuffer.empty[Double] + val values = ArrayBuilder.make[Double] var parsing = true var allowComma = false var token: String = null @@ -67,14 +67,14 @@ private[mllib] object NumericParser { } } else { // expecting a number - values.append(parseDouble(token)) + values += parseDouble(token) allowComma = true } } if (parsing) { throw new SparkException(s"An array must end with ']'.") } - values.toArray + values.result() } private def parseTuple(tokenizer: StringTokenizer): Seq[_] = { @@ -114,7 +114,7 @@ private[mllib] object NumericParser { try { java.lang.Double.parseDouble(s) } catch { - case e: Throwable => + case e: NumberFormatException => throw new SparkException(s"Cannot parse a double from: $s", e) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala index 56b77a7d12e83..4458340497f0b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala @@ -20,13 +20,13 @@ package org.apache.spark.mllib.util import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.{DataType, StructType, StructField} - +import org.apache.spark.sql.types.{DataType, StructField, StructType} /** * :: DeveloperApi :: @@ -120,20 +120,11 @@ private[mllib] object Loader { * Load metadata from the given path. * @return (class name, version, metadata) */ - def loadMetadata(sc: SparkContext, path: String): (String, String, DataFrame) = { - val sqlContext = new SQLContext(sc) - val metadata = sqlContext.jsonFile(metadataPath(path)) - val (clazz, version) = try { - val metadataArray = metadata.select("class", "version").take(1) - assert(metadataArray.size == 1) - metadataArray(0) match { - case Row(clazz: String, version: String) => (clazz, version) - } - } catch { - case e: Exception => - throw new Exception(s"Unable to load model metadata from: ${metadataPath(path)}") - } + def loadMetadata(sc: SparkContext, path: String): (String, String, JValue) = { + implicit val formats = DefaultFormats + val metadata = parse(sc.textFile(metadataPath(path)).first()) + val clazz = (metadata \ "class").extract[String] + val version = (metadata \ "version").extract[String] (clazz, version, metadata) } - } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index cb7d57de35c34..b118a8dcf1363 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -358,8 +358,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { .setNumUserBlocks(numUserBlocks) .setNumItemBlocks(numItemBlocks) val alpha = als.getAlpha - val model = als.fit(training) - val predictions = model.transform(test) + val model = als.fit(training.toDF) + val predictions = model.transform(test.toDF) .select("rating", "prediction") .map { case Row(rating: Float, prediction: Float) => (rating.toDouble, prediction.toDouble) diff --git a/pom.xml b/pom.xml index 56e37d42265c0..6810d71be4230 100644 --- a/pom.xml +++ b/pom.xml @@ -619,19 +619,6 @@ 2.2.1 test - - org.easymock - easymockclassextension - 3.1 - test - - - - asm - asm - 3.3.1 - test - org.mockito mockito-all @@ -1096,6 +1083,12 @@ scala-maven-plugin 3.2.0 + + eclipse-add-source + + add-source + + scala-compile-first process-resources diff --git a/python/docs/conf.py b/python/docs/conf.py index b00dce95d65b4..cbbf7ffb08992 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -97,6 +97,10 @@ # If true, keep warnings as "system message" paragraphs in the built documents. #keep_warnings = False +# -- Options for autodoc -------------------------------------------------- + +# Look at the first line of the docstring for function and method signatures. +autodoc_docstring_signature = True # -- Options for HTML output ---------------------------------------------- diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst index 80c6f02a9df41..e03379e521a07 100644 --- a/python/docs/pyspark.sql.rst +++ b/python/docs/pyspark.sql.rst @@ -16,3 +16,11 @@ pyspark.sql.types module :members: :undoc-members: :show-inheritance: + + +pyspark.sql.functions module +------------------------ +.. automodule:: pyspark.sql.functions + :members: + :undoc-members: + :show-inheritance: diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 6bd2aa8e47837..b6de7493d7523 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -15,7 +15,7 @@ # limitations under the License. # -from pyspark.ml.util import inherit_doc +from pyspark.ml.util import inherit_doc, keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\ HasRegParam @@ -32,22 +32,46 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors - >>> dataset = sqlCtx.inferSchema(sc.parallelize([ \ - Row(label=1.0, features=Vectors.dense(1.0)), \ - Row(label=0.0, features=Vectors.sparse(1, [], []))])) - >>> lr = LogisticRegression() \ - .setMaxIter(5) \ - .setRegParam(0.01) - >>> model = lr.fit(dataset) - >>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))])) + >>> df = sc.parallelize([ + ... Row(label=1.0, features=Vectors.dense(1.0)), + ... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF() + >>> lr = LogisticRegression(maxIter=5, regParam=0.01) + >>> model = lr.fit(df) + >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() >>> print model.transform(test0).head().prediction 0.0 - >>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))])) + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() >>> print model.transform(test1).head().prediction 1.0 + >>> lr.setParams("vector") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. """ _java_class = "org.apache.spark.ml.classification.LogisticRegression" + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, regParam=0.1): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, regParam=0.1) + """ + super(LogisticRegression, self).__init__() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, regParam=0.1): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, regParam=0.1) + Sets params for logistic regression. + """ + kwargs = self.setParams._input_kwargs + return self._set_params(**kwargs) + def _create_model(self, java_model): return LogisticRegressionModel(java_model) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index e088acd0ca82d..f1ddbb478dd9c 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -16,7 +16,7 @@ # from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures -from pyspark.ml.util import inherit_doc +from pyspark.ml.util import inherit_doc, keyword_only from pyspark.ml.wrapper import JavaTransformer __all__ = ['Tokenizer', 'HashingTF'] @@ -29,18 +29,45 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): splits it by white spaces. >>> from pyspark.sql import Row - >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(text="a b c")])) - >>> tokenizer = Tokenizer() \ - .setInputCol("text") \ - .setOutputCol("words") - >>> print tokenizer.transform(dataset).head() + >>> df = sc.parallelize([Row(text="a b c")]).toDF() + >>> tokenizer = Tokenizer(inputCol="text", outputCol="words") + >>> print tokenizer.transform(df).head() Row(text=u'a b c', words=[u'a', u'b', u'c']) - >>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).head() + >>> # Change a parameter. + >>> print tokenizer.setParams(outputCol="tokens").transform(df).head() Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + >>> # Temporarily modify a parameter. + >>> print tokenizer.transform(df, {tokenizer.outputCol: "words"}).head() + Row(text=u'a b c', words=[u'a', u'b', u'c']) + >>> print tokenizer.transform(df).head() + Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + >>> # Must use keyword arguments to specify params. + >>> tokenizer.setParams("text") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. """ _java_class = "org.apache.spark.ml.feature.Tokenizer" + @keyword_only + def __init__(self, inputCol="input", outputCol="output"): + """ + __init__(self, inputCol="input", outputCol="output") + """ + super(Tokenizer, self).__init__() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCol="input", outputCol="output"): + """ + setParams(self, inputCol="input", outputCol="output") + Sets params for this Tokenizer. + """ + kwargs = self.setParams._input_kwargs + return self._set_params(**kwargs) + @inherit_doc class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): @@ -49,20 +76,37 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): hashing trick. >>> from pyspark.sql import Row - >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(words=["a", "b", "c"])])) - >>> hashingTF = HashingTF() \ - .setNumFeatures(10) \ - .setInputCol("words") \ - .setOutputCol("features") - >>> print hashingTF.transform(dataset).head().features + >>> df = sc.parallelize([Row(words=["a", "b", "c"])]).toDF() + >>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + >>> print hashingTF.transform(df).head().features + (10,[7,8,9],[1.0,1.0,1.0]) + >>> print hashingTF.setParams(outputCol="freqs").transform(df).head().freqs (10,[7,8,9],[1.0,1.0,1.0]) >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"} - >>> print hashingTF.transform(dataset, params).head().vector + >>> print hashingTF.transform(df, params).head().vector (5,[2,3,4],[1.0,1.0,1.0]) """ _java_class = "org.apache.spark.ml.feature.HashingTF" + @keyword_only + def __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + """ + __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + """ + super(HashingTF, self).__init__() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + """ + setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + Sets params for this HashingTF. + """ + kwargs = self.setParams._input_kwargs + return self._set_params(**kwargs) + if __name__ == "__main__": import doctest diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 5566792cead48..e3a53dd780c4c 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -80,3 +80,11 @@ def _dummy(): dummy = Params() dummy.uid = "undefined" return dummy + + def _set_params(self, **kwargs): + """ + Sets params. + """ + for param, value in kwargs.iteritems(): + self.paramMap[getattr(self, param)] = value + return self diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 2d239f8c802a0..18d8a58f357bd 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -18,7 +18,7 @@ from abc import ABCMeta, abstractmethod from pyspark.ml.param import Param, Params -from pyspark.ml.util import inherit_doc +from pyspark.ml.util import inherit_doc, keyword_only __all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel'] @@ -89,10 +89,16 @@ class Pipeline(Estimator): identity transformer. """ - def __init__(self): + @keyword_only + def __init__(self, stages=[]): + """ + __init__(self, stages=[]) + """ super(Pipeline, self).__init__() #: Param for pipeline stages. self.stages = Param(self, "stages", "pipeline stages") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) def setStages(self, value): """ @@ -110,6 +116,15 @@ def getStages(self): if self.stages in self.paramMap: return self.paramMap[self.stages] + @keyword_only + def setParams(self, stages=[]): + """ + setParams(self, stages=[]) + Sets params for Pipeline. + """ + kwargs = self.setParams._input_kwargs + return self._set_params(**kwargs) + def fit(self, dataset, params={}): paramMap = self._merge_params(params) stages = paramMap[self.stages] diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index b1caa84b6306a..81d3f0882b8a9 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -15,6 +15,7 @@ # limitations under the License. # +from functools import wraps import uuid @@ -32,6 +33,20 @@ def inherit_doc(cls): return cls +def keyword_only(func): + """ + A decorator that forces keyword arguments in the wrapped method + and saves actual input keyword arguments in `_input_kwargs`. + """ + @wraps(func) + def wrapper(*args, **kwargs): + if len(args) > 1: + raise TypeError("Method %s forces keyword arguments." % func.__name__) + wrapper._input_kwargs = kwargs + return func(*args, **kwargs) + return wrapper + + class Identifiable(object): """ Object with a unique ID. diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 49e5c9d58e5db..06207a076eece 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -335,7 +335,7 @@ def test_infer_schema(self): sqlCtx = SQLContext(self.sc) rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) srdd = sqlCtx.inferSchema(rdd) - schema = srdd.schema() + schema = srdd.schema field = [f for f in schema.fields if f.name == "features"][0] self.assertEqual(field.dataType, self.udt) vectors = srdd.map(lambda p: p.features).collect() diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 0a5ba00393aab..b9ffd6945ea7e 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -34,9 +34,8 @@ from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.types import Row -from pyspark.sql.dataframe import DataFrame, GroupedData, Column, Dsl, SchemaRDD +from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD __all__ = [ 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', - 'Dsl', ] diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index db4bcbece2c1b..7683c1b4dfa4e 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -38,6 +38,25 @@ __all__ = ["SQLContext", "HiveContext"] +def _monkey_patch_RDD(sqlCtx): + def toDF(self, schema=None, sampleRatio=None): + """ + Convert current :class:`RDD` into a :class:`DataFrame` + + This is a shorthand for `sqlCtx.createDataFrame(rdd, schema, sampleRatio)` + + :param schema: a StructType or list of names of columns + :param samplingRatio: the sample ratio of rows used for inferring + :return: a DataFrame + + >>> rdd.toDF().collect() + [Row(name=u'Alice', age=1)] + """ + return sqlCtx.createDataFrame(self, schema, sampleRatio) + + RDD.toDF = toDF + + class SQLContext(object): """Main entry point for Spark SQL functionality. @@ -49,15 +68,20 @@ class SQLContext(object): def __init__(self, sparkContext, sqlContext=None): """Create a new SQLContext. + It will add a method called `toDF` to :class:`RDD`, which could be + used to convert an RDD into a DataFrame, it's a shorthand for + :func:`SQLContext.createDataFrame`. + :param sparkContext: The SparkContext to wrap. :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new SQLContext in the JVM, instead we make all calls to this object. >>> from datetime import datetime + >>> sqlCtx = SQLContext(sc) >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) - >>> df = sqlCtx.createDataFrame(allTypes) + >>> df = allTypes.toDF() >>> df.registerTempTable("allTypes") >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() @@ -70,6 +94,7 @@ def __init__(self, sparkContext, sqlContext=None): self._jsc = self._sc._jsc self._jvm = self._sc._jvm self._scala_SQLContext = sqlContext + _monkey_patch_RDD(self) @property def _ssql_ctx(self): @@ -442,7 +467,7 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema()) + >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema) >>> sqlCtx.registerRDDAsTable(df3, "table2") >>> df4 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " @@ -495,7 +520,7 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - >>> df3 = sqlCtx.jsonRDD(json, df1.schema()) + >>> df3 = sqlCtx.jsonRDD(json, df1.schema) >>> sqlCtx.registerRDDAsTable(df3, "table2") >>> df4 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " @@ -621,6 +646,40 @@ def table(self, tableName): """ return DataFrame(self._ssql_ctx.table(tableName), self) + def tables(self, dbName=None): + """Returns a DataFrame containing names of tables in the given database. + + If `dbName` is not specified, the current database will be used. + + The returned DataFrame has two columns, tableName and isTemporary + (a column with BooleanType indicating if a table is a temporary one or not). + + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.tables() + >>> df2.filter("tableName = 'table1'").first() + Row(tableName=u'table1', isTemporary=True) + """ + if dbName is None: + return DataFrame(self._ssql_ctx.tables(), self) + else: + return DataFrame(self._ssql_ctx.tables(dbName), self) + + def tableNames(self, dbName=None): + """Returns a list of names of tables in the database `dbName`. + + If `dbName` is not specified, the current database will be used. + + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> "table1" in sqlCtx.tableNames() + True + >>> "table1" in sqlCtx.tableNames("db") + True + """ + if dbName is None: + return [name for name in self._ssql_ctx.tableNames()] + else: + return [name for name in self._ssql_ctx.tableNames(dbName)] + def cacheTable(self, tableName): """Caches the specified table in-memory.""" self._ssql_ctx.cacheTable(tableName) @@ -766,7 +825,8 @@ def _test(): Row(field1=2, field2="row2"), Row(field1=3, field2="row3")] ) - globs['df'] = sqlCtx.createDataFrame(rdd) + _monkey_patch_RDD(sqlCtx) + globs['df'] = rdd.toDF() jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 3eb56ed74cc6f..1438fe5285cc5 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -21,21 +21,19 @@ import random import os from tempfile import NamedTemporaryFile -from itertools import imap from py4j.java_collections import ListConverter, MapConverter from pyspark.context import SparkContext -from pyspark.rdd import RDD, _prepare_for_python_RDD -from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ - UTF8Deserializer +from pyspark.rdd import RDD +from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import * from pyspark.sql.types import _create_cls, _parse_datatype_json_string -__all__ = ["DataFrame", "GroupedData", "Column", "Dsl", "SchemaRDD"] +__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD"] class DataFrame(object): @@ -76,6 +74,7 @@ def __init__(self, jdf, sql_ctx): self.sql_ctx = sql_ctx self._sc = sql_ctx and sql_ctx._sc self.is_cached = False + self._schema = None # initialized lazily @property def rdd(self): @@ -86,7 +85,7 @@ def rdd(self): if not hasattr(self, '_lazy_rdd'): jrdd = self._jdf.javaToPython() rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) - schema = self.schema() + schema = self.schema def applySchema(it): cls = _create_cls(schema) @@ -149,7 +148,7 @@ def insertInto(self, tableName, overwrite=False): def _java_save_mode(self, mode): """Returns the Java save mode based on the Python save mode represented by a string. """ - jSaveMode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode + jSaveMode = self._sc._jvm.org.apache.spark.sql.SaveMode jmode = jSaveMode.ErrorIfExists mode = mode.lower() if mode == "append": @@ -216,14 +215,17 @@ def save(self, path=None, source=None, mode="append", **options): self._sc._gateway._gateway_client) self._jdf.save(source, jmode, joptions) + @property def schema(self): """Returns the schema of this DataFrame (represented by a L{StructType}). - >>> df.schema() + >>> df.schema StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) """ - return _parse_datatype_json_string(self._jdf.schema().json()) + if self._schema is None: + self._schema = _parse_datatype_json_string(self._jdf.schema().json()) + return self._schema def printSchema(self): """Prints out the schema in the tree format. @@ -284,7 +286,7 @@ def collect(self): with open(tempFile.name, 'rb') as tempFile: rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) os.unlink(tempFile.name) - cls = _create_cls(self.schema()) + cls = _create_cls(self.schema) return [cls(r) for r in rs] def limit(self, num): @@ -310,14 +312,26 @@ def take(self, num): return self.limit(num).collect() def map(self, f): - """ Return a new RDD by applying a function to each Row, it's a - shorthand for df.rdd.map() + """ Return a new RDD by applying a function to each Row + + It's a shorthand for df.rdd.map() >>> df.map(lambda p: p.name).collect() [u'Alice', u'Bob'] """ return self.rdd.map(f) + def flatMap(self, f): + """ Return a new RDD by first applying a function to all elements of this, + and then flattening the results. + + It's a shorthand for df.rdd.flatMap() + + >>> df.flatMap(lambda p: p.name).collect() + [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b'] + """ + return self.rdd.flatMap(f) + def mapPartitions(self, f, preservesPartitioning=False): """ Return a new RDD by applying a function to each partition. @@ -378,21 +392,6 @@ def sample(self, withReplacement, fraction, seed=None): rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) - # def takeSample(self, withReplacement, num, seed=None): - # """Return a fixed-size sampled subset of this DataFrame. - # - # >>> df = sqlCtx.inferSchema(rdd) - # >>> df.takeSample(False, 2, 97) - # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] - # """ - # seed = seed if seed is not None else random.randint(0, sys.maxint) - # with SCCallSiteSync(self.context) as css: - # bytesInJava = self._jdf \ - # .takeSampleToPython(withReplacement, num, long(seed)) \ - # .iterator() - # cls = _create_cls(self.schema()) - # return map(cls, self._collect_iterator_through_file(bytesInJava)) - @property def dtypes(self): """Return all column names and their data types as a list. @@ -400,7 +399,7 @@ def dtypes(self): >>> df.dtypes [('age', 'int'), ('name', 'string')] """ - return [(str(f.name), f.dataType.simpleString()) for f in self.schema().fields] + return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields] @property def columns(self): @@ -409,7 +408,7 @@ def columns(self): >>> df.columns [u'age', u'name'] """ - return [f.name for f in self.schema().fields] + return [f.name for f in self.schema.fields] def join(self, other, joinExprs=None, joinType=None): """ @@ -586,8 +585,8 @@ def agg(self, *exprs): >>> df.agg({"age": "max"}).collect() [Row(MAX(age#0)=5)] - >>> from pyspark.sql import Dsl - >>> df.agg(Dsl.min(df.age)).collect() + >>> from pyspark.sql import functions as F + >>> df.agg(F.min(df.age)).collect() [Row(MIN(age#0)=2)] """ return self.groupBy().agg(*exprs) @@ -616,18 +615,18 @@ def subtract(self, other): """ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) - def addColumn(self, colName, col): + def withColumn(self, colName, col): """ Return a new :class:`DataFrame` by adding a column. - >>> df.addColumn('age2', df.age + 2).collect() + >>> df.withColumn('age2', df.age + 2).collect() [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] """ return self.select('*', col.alias(colName)) - def renameColumn(self, existing, new): + def withColumnRenamed(self, existing, new): """ Rename an existing column to a new name - >>> df.renameColumn('age', 'age2').collect() + >>> df.withColumnRenamed('age', 'age2').collect() [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] """ cols = [Column(_to_java_column(c), self.sql_ctx).alias(new) @@ -635,11 +634,11 @@ def renameColumn(self, existing, new): for c in self.columns] return self.select(*cols) - def to_pandas(self): + def toPandas(self): """ Collect all the rows and return a `pandas.DataFrame`. - >>> df.to_pandas() # doctest: +SKIP + >>> df.toPandas() # doctest: +SKIP age name 0 2 Alice 1 5 Bob @@ -687,10 +686,11 @@ def agg(self, *exprs): name to aggregate methods. >>> gdf = df.groupBy(df.name) - >>> gdf.agg({"age": "max"}).collect() - [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)] - >>> from pyspark.sql import Dsl - >>> gdf.agg(Dsl.min(df.age)).collect() + >>> gdf.agg({"*": "count"}).collect() + [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)] + + >>> from pyspark.sql import functions as F + >>> gdf.agg(F.min(df.age)).collect() [Row(MIN(age#0)=5), Row(MIN(age#0)=2)] """ assert exprs, "exprs should not be empty" @@ -742,12 +742,12 @@ def sum(self): def _create_column_from_literal(literal): sc = SparkContext._active_spark_context - return sc._jvm.Dsl.lit(literal) + return sc._jvm.functions.lit(literal) def _create_column_from_name(name): sc = SparkContext._active_spark_context - return sc._jvm.Dsl.col(name) + return sc._jvm.functions.col(name) def _to_java_column(col): @@ -767,9 +767,9 @@ def _(self): return _ -def _dsl_op(name, doc=''): +def _func_op(name, doc=''): def _(self): - jc = getattr(self._sc._jvm.Dsl, name)(self._jc) + jc = getattr(self._sc._jvm.functions, name)(self._jc) return Column(jc, self.sql_ctx) _.__doc__ = doc return _ @@ -818,7 +818,7 @@ def __init__(self, jc, sql_ctx=None): super(Column, self).__init__(jc, sql_ctx) # arithmetic operators - __neg__ = _dsl_op("negate") + __neg__ = _func_op("negate") __add__ = _bin_op("plus") __sub__ = _bin_op("minus") __mul__ = _bin_op("multiply") @@ -842,7 +842,7 @@ def __init__(self, jc, sql_ctx=None): # so use bitwise operators as boolean operators __and__ = _bin_op('and') __or__ = _bin_op('or') - __invert__ = _dsl_op('not') + __invert__ = _func_op('not') __rand__ = _bin_op("and") __ror__ = _bin_op("or") @@ -920,11 +920,11 @@ def __repr__(self): else: return 'Column<%s>' % self._jdf.toString() - def to_pandas(self): + def toPandas(self): """ Return a pandas.Series from the column - >>> df.age.to_pandas() # doctest: +SKIP + >>> df.age.toPandas() # doctest: +SKIP 0 2 1 5 dtype: int64 @@ -934,123 +934,6 @@ def to_pandas(self): return pd.Series(data) -def _aggregate_func(name, doc=""): - """ Create a function for aggregator by name""" - def _(col): - sc = SparkContext._active_spark_context - jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col)) - return Column(jc) - _.__name__ = name - _.__doc__ = doc - return staticmethod(_) - - -class UserDefinedFunction(object): - def __init__(self, func, returnType): - self.func = func - self.returnType = returnType - self._broadcast = None - self._judf = self._create_judf() - - def _create_judf(self): - f = self.func # put it in closure `func` - func = lambda _, it: imap(lambda x: f(*x), it) - ser = AutoBatchedSerializer(PickleSerializer()) - command = (func, None, ser, ser) - sc = SparkContext._active_spark_context - pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(self.returnType.json()) - judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env, - includes, sc.pythonExec, broadcast_vars, - sc._javaAccumulator, jdt) - return judf - - def __del__(self): - if self._broadcast is not None: - self._broadcast.unpersist() - self._broadcast = None - - def __call__(self, *cols): - sc = SparkContext._active_spark_context - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - sc._gateway._gateway_client) - jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols)) - return Column(jc) - - -class Dsl(object): - """ - A collections of builtin aggregators - """ - DSLS = { - 'lit': 'Creates a :class:`Column` of literal value.', - 'col': 'Returns a :class:`Column` based on the given column name.', - 'column': 'Returns a :class:`Column` based on the given column name.', - 'upper': 'Converts a string expression to upper case.', - 'lower': 'Converts a string expression to upper case.', - 'sqrt': 'Computes the square root of the specified float value.', - 'abs': 'Computes the absolutle value.', - - 'max': 'Aggregate function: returns the maximum value of the expression in a group.', - 'min': 'Aggregate function: returns the minimum value of the expression in a group.', - 'first': 'Aggregate function: returns the first value in a group.', - 'last': 'Aggregate function: returns the last value in a group.', - 'count': 'Aggregate function: returns the number of items in a group.', - 'sum': 'Aggregate function: returns the sum of all values in the expression.', - 'avg': 'Aggregate function: returns the average of the values in a group.', - 'mean': 'Aggregate function: returns the average of the values in a group.', - 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', - } - - for _name, _doc in DSLS.items(): - locals()[_name] = _aggregate_func(_name, _doc) - del _name, _doc - - @staticmethod - def countDistinct(col, *cols): - """ Return a new Column for distinct count of (col, *cols) - - >>> from pyspark.sql import Dsl - >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect() - [Row(c=2)] - - >>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect() - [Row(c=2)] - """ - sc = SparkContext._active_spark_context - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - sc._gateway._gateway_client) - jc = sc._jvm.Dsl.countDistinct(_to_java_column(col), - sc._jvm.PythonUtils.toSeq(jcols)) - return Column(jc) - - @staticmethod - def approxCountDistinct(col, rsd=None): - """ Return a new Column for approxiate distinct count of (col, *cols) - - >>> from pyspark.sql import Dsl - >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect() - [Row(c=2)] - """ - sc = SparkContext._active_spark_context - if rsd is None: - jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col)) - else: - jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd) - return Column(jc) - - @staticmethod - def udf(f, returnType=StringType()): - """Create a user defined function (UDF) - - >>> slen = Dsl.udf(lambda s: len(s), IntegerType()) - >>> df.select(slen(df.name).alias('slen')).collect() - [Row(slen=5), Row(slen=3)] - """ - return UserDefinedFunction(f, returnType) - - def _test(): import doctest from pyspark.context import SparkContext @@ -1059,11 +942,9 @@ def _test(): globs = pyspark.sql.dataframe.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc - globs['sqlCtx'] = sqlCtx = SQLContext(sc) - rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]) - rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]) - globs['df'] = sqlCtx.inferSchema(rdd2) - globs['df2'] = sqlCtx.inferSchema(rdd3) + globs['sqlCtx'] = SQLContext(sc) + globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() + globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.dataframe, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py new file mode 100644 index 0000000000000..39aa550eeb5ad --- /dev/null +++ b/python/pyspark/sql/functions.py @@ -0,0 +1,170 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A collections of builtin functions +""" + +from itertools import imap + +from py4j.java_collections import ListConverter + +from pyspark import SparkContext +from pyspark.rdd import _prepare_for_python_RDD +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.sql.types import StringType +from pyspark.sql.dataframe import Column, _to_java_column + + +__all__ = ['countDistinct', 'approxCountDistinct', 'udf'] + + +def _create_function(name, doc=""): + """ Create a function for aggregator by name""" + def _(col): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)(_to_java_column(col)) + return Column(jc) + _.__name__ = name + _.__doc__ = doc + return _ + + +_functions = { + 'lit': 'Creates a :class:`Column` of literal value.', + 'col': 'Returns a :class:`Column` based on the given column name.', + 'column': 'Returns a :class:`Column` based on the given column name.', + 'upper': 'Converts a string expression to upper case.', + 'lower': 'Converts a string expression to upper case.', + 'sqrt': 'Computes the square root of the specified float value.', + 'abs': 'Computes the absolutle value.', + + 'max': 'Aggregate function: returns the maximum value of the expression in a group.', + 'min': 'Aggregate function: returns the minimum value of the expression in a group.', + 'first': 'Aggregate function: returns the first value in a group.', + 'last': 'Aggregate function: returns the last value in a group.', + 'count': 'Aggregate function: returns the number of items in a group.', + 'sum': 'Aggregate function: returns the sum of all values in the expression.', + 'avg': 'Aggregate function: returns the average of the values in a group.', + 'mean': 'Aggregate function: returns the average of the values in a group.', + 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', +} + + +for _name, _doc in _functions.items(): + globals()[_name] = _create_function(_name, _doc) +del _name, _doc +__all__ += _functions.keys() + + +def countDistinct(col, *cols): + """ Return a new Column for distinct count of `col` or `cols` + + >>> df.agg(countDistinct(df.age, df.name).alias('c')).collect() + [Row(c=2)] + + >>> df.agg(countDistinct("age", "name").alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client) + jc = sc._jvm.functions.countDistinct(_to_java_column(col), sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + + +def approxCountDistinct(col, rsd=None): + """ Return a new Column for approximate distinct count of `col` + + >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + if rsd is None: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col)) + else: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd) + return Column(jc) + + +class UserDefinedFunction(object): + """ + User defined function in Python + """ + def __init__(self, func, returnType): + self.func = func + self.returnType = returnType + self._broadcast = None + self._judf = self._create_judf() + + def _create_judf(self): + f = self.func # put it in closure `func` + func = lambda _, it: imap(lambda x: f(*x), it) + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, None, ser, ser) + sc = SparkContext._active_spark_context + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(self.returnType.json()) + judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env, + includes, sc.pythonExec, broadcast_vars, + sc._javaAccumulator, jdt) + return judf + + def __del__(self): + if self._broadcast is not None: + self._broadcast.unpersist() + self._broadcast = None + + def __call__(self, *cols): + sc = SparkContext._active_spark_context + jcols = ListConverter().convert([_to_java_column(c) for c in cols], + sc._gateway._gateway_client) + jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + + +def udf(f, returnType=StringType()): + """Create a user defined function (UDF) + + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> df.select(slen(df.name).alias('slen')).collect() + [Row(slen=5), Row(slen=3)] + """ + return UserDefinedFunction(f, returnType) + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.dataframe + globs = pyspark.sql.dataframe.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlCtx'] = SQLContext(sc) + globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() + globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() + (failure_count, test_count) = doctest.testmod( + pyspark.sql.dataframe, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 43e5c3a1b00fa..aa80bca34655d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -96,7 +96,7 @@ def setUpClass(cls): cls.sqlCtx = SQLContext(cls.sc) cls.testData = [Row(key=i, value=str(i)) for i in range(100)] rdd = cls.sc.parallelize(cls.testData) - cls.df = cls.sqlCtx.createDataFrame(rdd) + cls.df = rdd.toDF() @classmethod def tearDownClass(cls): @@ -138,7 +138,7 @@ def test_basic_functions(self): df = self.sqlCtx.jsonRDD(rdd) df.count() df.collect() - df.schema() + df.schema # cache and checkpoint self.assertFalse(df.is_cached) @@ -155,11 +155,11 @@ def test_basic_functions(self): def test_apply_schema_to_row(self): df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) - df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema()) + df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema) self.assertEqual(df.collect(), df2.collect()) rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - df3 = self.sqlCtx.createDataFrame(rdd, df.schema()) + df3 = self.sqlCtx.createDataFrame(rdd, df.schema) self.assertEqual(10, df3.count()) def test_serialize_nested_array_and_map(self): @@ -195,7 +195,7 @@ def test_infer_schema(self): self.assertEqual(1, result.head()[0]) df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) - self.assertEqual(df.schema(), df2.schema()) + self.assertEqual(df.schema, df2.schema) self.assertEqual({}, df2.map(lambda r: r.d).first()) self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) df2.registerTempTable("test2") @@ -204,8 +204,7 @@ def test_infer_schema(self): def test_struct_in_map(self): d = [Row(m={Row(i=1): Row(s="")})] - rdd = self.sc.parallelize(d) - df = self.sqlCtx.createDataFrame(rdd) + df = self.sc.parallelize(d).toDF() k, v = df.head().m.items()[0] self.assertEqual(1, k.i) self.assertEqual("", v.s) @@ -213,8 +212,7 @@ def test_struct_in_map(self): def test_convert_row_to_dict(self): row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) self.assertEqual(1, row.asDict()['l'][0].a) - rdd = self.sc.parallelize([row]) - df = self.sqlCtx.createDataFrame(rdd) + df = self.sc.parallelize([row]).toDF() df.registerTempTable("test") row = self.sqlCtx.sql("select l, d from test").head() self.assertEqual(1, row.asDict()["l"][0].a) @@ -223,9 +221,8 @@ def test_convert_row_to_dict(self): def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - df = self.sqlCtx.createDataFrame(rdd) - schema = df.schema() + df = self.sc.parallelize([row]).toDF() + schema = df.schema field = [f for f in schema.fields if f.name == "point"][0] self.assertEqual(type(field.dataType), ExamplePointUDT) df.registerTempTable("labeled_point") @@ -238,15 +235,14 @@ def test_apply_schema_with_udt(self): rdd = self.sc.parallelize([row]) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), False)]) - df = self.sqlCtx.createDataFrame(rdd, schema) + df = rdd.toDF(schema) point = df.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) def test_parquet_with_udt(self): from pyspark.sql.tests import ExamplePoint row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) - df0 = self.sqlCtx.createDataFrame(rdd) + df0 = self.sc.parallelize([row]).toDF() output_dir = os.path.join(self.tempdir.name, "labeled_point") df0.saveAsParquetFile(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) @@ -280,10 +276,11 @@ def test_aggregator(self): self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) - from pyspark.sql import Dsl - self.assertEqual((0, u'99'), tuple(g.agg(Dsl.first(df.key), Dsl.last(df.value)).first())) - self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0]) - self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0]) + from pyspark.sql import functions + self.assertEqual((0, u'99'), + tuple(g.agg(functions.first(df.key), functions.last(df.value)).first())) + self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) + self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) def test_save_and_load(self): df = self.df @@ -339,8 +336,7 @@ def setUpClass(cls): cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc()) cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext) cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - rdd = cls.sc.parallelize(cls.testData) - cls.df = cls.sqlCtx.inferSchema(rdd) + cls.df = cls.sc.parallelize(cls.testData).toDF() @classmethod def tearDownClass(cls): diff --git a/python/run-tests b/python/run-tests index 58a26dd8ff088..a2c2f37a54eda 100755 --- a/python/run-tests +++ b/python/run-tests @@ -35,7 +35,7 @@ rm -rf metastore warehouse function run_test() { echo "Running test: $1" | tee -a $LOG_FILE - SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a $LOG_FILE + SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 > $LOG_FILE 2>&1 FAILED=$((PIPESTATUS[0]||$FAILED)) @@ -67,6 +67,7 @@ function run_sql_tests() { run_test "pyspark/sql/types.py" run_test "pyspark/sql/context.py" run_test "pyspark/sql/dataframe.py" + run_test "pyspark/sql/functions.py" run_test "pyspark/sql/tests.py" } diff --git a/repl/pom.xml b/repl/pom.xml index 3d4adf8fd5b03..b883344bf0ceb 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -33,8 +33,6 @@ repl - /usr/share/spark - root scala-2.10/src/main/scala scala-2.10/src/test/scala diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index 0cf2de6d399b0..05faef8786d2c 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -137,7 +137,7 @@ private[repl] trait SparkILoopInit { command("import org.apache.spark.SparkContext._") command("import sqlContext.implicits._") command("import sqlContext.sql") - command("import org.apache.spark.sql.Dsl._") + command("import org.apache.spark.sql.functions._") } } diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 201f2672d5474..529914a2b6141 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -262,7 +262,7 @@ class ReplSuite extends FunSuite { |val sqlContext = new org.apache.spark.sql.SQLContext(sc) |import sqlContext.implicits._ |case class TestCaseClass(value: Int) - |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDataFrame.collect() + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF.collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 1bd2a6991404b..7a5e94da5cbf3 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -77,7 +77,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter) command("import org.apache.spark.SparkContext._") command("import sqlContext.implicits._") command("import sqlContext.sql") - command("import org.apache.spark.sql.Dsl._") + command("import org.apache.spark.sql.functions._") } } diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index 89608bc41b71d..ec6d0b5a40ef2 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -129,8 +129,9 @@ case $option in mkdir -p "$SPARK_PID_DIR" if [ -f $pid ]; then - if kill -0 `cat $pid` > /dev/null 2>&1; then - echo $command running as process `cat $pid`. Stop it first. + TARGET_ID="$(cat "$pid")" + if [[ $(ps -p "$TARGET_ID" -o args=) =~ $command ]]; then + echo "$command running as process $TARGET_ID. Stop it first." exit 1 fi fi @@ -141,7 +142,7 @@ case $option in fi spark_rotate_log "$log" - echo starting $command, logging to $log + echo "starting $command, logging to $log" if [ $option == spark-submit ]; then source "$SPARK_HOME"/bin/utils.sh gatherSparkSubmitOpts "$@" @@ -154,7 +155,7 @@ case $option in echo $newpid > $pid sleep 2 # Check if the process has died; in that case we'll tail the log so the user can see - if ! kill -0 $newpid >/dev/null 2>&1; then + if [[ ! $(ps -p "$newpid" -o args=) =~ $command ]]; then echo "failed to launch $command:" tail -2 "$log" | sed 's/^/ /' echo "full log in $log" @@ -164,14 +165,15 @@ case $option in (stop) if [ -f $pid ]; then - if kill -0 `cat $pid` > /dev/null 2>&1; then - echo stopping $command - kill `cat $pid` + TARGET_ID="$(cat "$pid")" + if [[ $(ps -p "$TARGET_ID" -o args=) =~ $command ]]; then + echo "stopping $command" + kill "$TARGET_ID" else - echo no $command to stop + echo "no $command to stop" fi else - echo no $command to stop + echo "no $command to stop" fi ;; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 5d9c331ca5178..11fd443733658 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -122,6 +122,21 @@ trait ScalaReflection { case t if t <:< typeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t Schema(schemaFor(optType).dataType, nullable = true) + // Need to decide if we actually need a special type here. + case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) + case t if t <:< typeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + case t if t <:< typeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + case t if t <:< typeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + Schema(MapType(schemaFor(keyType).dataType, + valueDataType, valueContainsNull = valueNullable), nullable = true) case t if t <:< typeOf[Product] => val formalTypeArgs = t.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = t @@ -144,21 +159,6 @@ trait ScalaReflection { schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) StructField(p.name.toString, dataType, nullable) }), nullable = true) - // Need to decide if we actually need a special type here. - case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) - case t if t <:< typeOf[Array[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) - Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< typeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) - Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< typeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - Schema(MapType(schemaFor(keyType).dataType, - valueDataType, valueContainsNull = valueNullable), nullable = true) case t if t <:< typeOf[String] => Schema(StringType, nullable = true) case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) case t if t <:< typeOf[java.sql.Date] => Schema(DateType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2d1fa106a2aa9..58a7003977c93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types.{ArrayType, StructField, StructType, IntegerType} +import org.apache.spark.sql.types._ /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing @@ -66,32 +66,82 @@ class Analyzer(catalog: Catalog, typeCoercionRules ++ extendedRules : _*), Batch("Check Analysis", Once, - CheckResolution :: - CheckAggregation :: - Nil: _*), - Batch("AnalysisOperators", fixedPoint, - EliminateAnalysisOperators) + CheckResolution), + Batch("Remove SubQueries", fixedPoint, + EliminateSubQueries) ) /** * Makes sure all attributes and logical plans have been resolved. */ object CheckResolution extends Rule[LogicalPlan] { + def failAnalysis(msg: String) = { throw new AnalysisException(msg) } + def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformUp { - case p if p.expressions.exists(!_.resolved) => - val missing = p.expressions.filterNot(_.resolved).map(_.prettyString).mkString(",") - val from = p.inputSet.map(_.name).mkString("{", ", ", "}") - - throw new AnalysisException(s"Cannot resolve '$missing' given input columns $from") - case p if !p.resolved && p.childrenResolved => - throw new AnalysisException(s"Unresolved operator in the query plan ${p.simpleString}") - } match { - // As a backstop, use the root node to check that the entire plan tree is resolved. - case p if !p.resolved => - throw new AnalysisException(s"Unresolved operator in the query plan ${p.simpleString}") - case p => p + // We transform up and order the rules so as to catch the first possible failure instead + // of the result of cascading resolution failures. + plan.foreachUp { + case operator: LogicalPlan => + operator transformExpressionsUp { + case a: Attribute if !a.resolved => + val from = operator.inputSet.map(_.name).mkString(", ") + failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") + + case c: Cast if !c.resolved => + failAnalysis( + s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") + + case b: BinaryExpression if !b.resolved => + failAnalysis( + s"invalid expression ${b.prettyString} " + + s"between ${b.left.simpleString} and ${b.right.simpleString}") + } + + operator match { + case f: Filter if f.condition.dataType != BooleanType => + failAnalysis( + s"filter expression '${f.condition.prettyString}' " + + s"of type ${f.condition.dataType.simpleString} is not a boolean.") + + case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) => + def checkValidAggregateExpression(expr: Expression): Unit = expr match { + case _: AggregateExpression => // OK + case e: Attribute if !groupingExprs.contains(e) => + failAnalysis( + s"expression '${e.prettyString}' is neither present in the group by, " + + s"nor is it an aggregate function. " + + "Add to group by or wrap in first() if you don't care which value you get.") + case e if groupingExprs.contains(e) => // OK + case e if e.references.isEmpty => // OK + case e => e.children.foreach(checkValidAggregateExpression) + } + + val cleaned = aggregateExprs.map(_.transform { + // Should trim aliases around `GetField`s. These aliases are introduced while + // resolving struct field accesses, because `GetField` is not a `NamedExpression`. + // (Should we just turn `GetField` into a `NamedExpression`?) + case Alias(g, _) => g + }) + + cleaned.foreach(checkValidAggregateExpression) + + case o if o.children.nonEmpty && + !o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) => + val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",") + val input = o.inputSet.map(_.prettyString).mkString(",") + + failAnalysis(s"resolved attributes $missingAttributes missing from $input") + + // Catch all + case o if !o.resolved => + failAnalysis( + s"unresolved operator ${operator.simpleString}") + + case _ => // Analysis successful! + } } + + plan } } @@ -192,37 +242,6 @@ class Analyzer(catalog: Catalog, } } - /** - * Checks for non-aggregated attributes with aggregation - */ - object CheckAggregation extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - plan.transform { - case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) => - def isValidAggregateExpression(expr: Expression): Boolean = expr match { - case _: AggregateExpression => true - case e: Attribute => groupingExprs.contains(e) - case e if groupingExprs.contains(e) => true - case e if e.references.isEmpty => true - case e => e.children.forall(isValidAggregateExpression) - } - - aggregateExprs.find { e => - !isValidAggregateExpression(e.transform { - // Should trim aliases around `GetField`s. These aliases are introduced while - // resolving struct field accesses, because `GetField` is not a `NamedExpression`. - // (Should we just turn `GetField` into a `NamedExpression`?) - case Alias(g: GetField, _) => g - }) - }.foreach { e => - throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e") - } - - aggregatePlan - } - } - } - /** * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ @@ -230,7 +249,7 @@ class Analyzer(catalog: Catalog, def apply(plan: LogicalPlan): LogicalPlan = plan transform { case i @ InsertIntoTable(UnresolvedRelation(tableIdentifier, alias), _, _, _) => i.copy( - table = EliminateAnalysisOperators(catalog.lookupRelation(tableIdentifier, alias))) + table = EliminateSubQueries(catalog.lookupRelation(tableIdentifier, alias))) case UnresolvedRelation(tableIdentifier, alias) => catalog.lookupRelation(tableIdentifier, alias) } @@ -477,7 +496,7 @@ class Analyzer(catalog: Catalog, * only required to provide scoping information for attributes and can be removed once analysis is * complete. */ -object EliminateAnalysisOperators extends Rule[LogicalPlan] { +object EliminateSubQueries extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Subquery(_, child) => child } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index df8d03b86c533..f57eab24607f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -34,6 +34,12 @@ trait Catalog { tableIdentifier: Seq[String], alias: Option[String] = None): LogicalPlan + /** + * Returns tuples of (tableName, isTemporary) for all tables in the given database. + * isTemporary is a Boolean value indicates if a table is a temporary or not. + */ + def getTables(databaseName: Option[String]): Seq[(String, Boolean)] + def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit def unregisterTable(tableIdentifier: Seq[String]): Unit @@ -101,6 +107,12 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { // properly qualified with this alias. alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) } + + override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { + tables.map { + case (name, _) => (name, true) + }.toSeq + } } /** @@ -137,6 +149,27 @@ trait OverrideCatalog extends Catalog { withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias)) } + abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { + val dbName = if (!caseSensitive) { + if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None + } else { + databaseName + } + + val temporaryTables = overrides.filter { + // If a temporary table does not have an associated database, we should return its name. + case ((None, _), _) => true + // If a temporary table does have an associated database, we should return it if the database + // matches the given database name. + case ((db: Some[String], _), _) if db == dbName => true + case _ => false + }.map { + case ((_, tableName), _) => (tableName, true) + }.toSeq + + temporaryTables ++ super.getTables(databaseName) + } + override def registerTable( tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { @@ -172,6 +205,10 @@ object EmptyCatalog extends Catalog { throw new UnsupportedOperationException } + override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { + throw new UnsupportedOperationException + } + def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index f959a50564011..a7cd4124e56f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -152,7 +152,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) override lazy val resolved = false - override def newInstance = this + override def newInstance() = this override def withNullability(newNullability: Boolean) = this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 171845ad14e3e..a9ba0be596349 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -20,7 +20,11 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.Star protected class AttributeEquals(val a: Attribute) { - override def hashCode() = a.exprId.hashCode() + override def hashCode() = a match { + case ar: AttributeReference => ar.exprId.hashCode() + case a => a.hashCode() + } + override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match { case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId case (a1, a2) => a1 == a2 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 43b6482c0171c..0983d274def3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -73,6 +73,25 @@ abstract class Generator extends Expression { } } +/** + * A generator that produces its output using the provided lambda function. + */ +case class UserDefinedGenerator( + schema: Seq[Attribute], + function: Row => TraversableOnce[Row], + children: Seq[Expression]) + extends Generator{ + + override protected def makeOutput(): Seq[Attribute] = schema + + override def eval(input: Row): TraversableOnce[Row] = { + val inputRow = new InterpretedProjection(children) + function(inputRow(input)) + } + + override def toString = s"UserDefinedGenerator(${children.mkString(",")})" +} + /** * Given an input array produces a sequence of rows for each value in the array. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index f77c56311cc8c..62c062be6d820 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -218,7 +218,7 @@ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[E override def exprId: ExprId = ??? override def eval(input: Row): EvaluatedType = ??? override def nullable: Boolean = ??? - override def dataType: DataType = ??? + override def dataType: DataType = NullType } object VirtualColumn { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0da081ed1a6e2..1a75fcf3545bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -119,6 +119,15 @@ object ColumnPruning extends Rule[LogicalPlan] { case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = Project(a.references.toSeq, child)) + case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child)) + if (a.outputSet -- p.references).nonEmpty => + Project( + projectList, + Aggregate( + groupingExpressions, + aggregateExpressions.filter(e => p.references.contains(e)), + child)) + // Eliminate unneeded attributes from either side of a Join. case Project(projectList, Join(left, right, joinType, condition)) => // Collect the list of all references required either above or to evaluate the condition. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala index cfe2c7a39a17c..ccf5291219add 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Attribute, Expression} /** * Transforms the input by forking and running the specified script. @@ -32,7 +32,9 @@ case class ScriptTransformation( script: String, output: Seq[Attribute], child: LogicalPlan, - ioschema: ScriptInputOutputSchema) extends UnaryNode + ioschema: ScriptInputOutputSchema) extends UnaryNode { + override def references: AttributeSet = AttributeSet(input.flatMap(_.references)) +} /** * A placeholder for implementation specific input and output properties when passing data diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 2013ae4f7bd13..e0930b056d5fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -46,6 +46,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { children.foreach(_.foreach(f)) } + /** + * Runs the given function recursively on [[children]] then on this node. + * @param f the function to be applied to each node in the tree. + */ + def foreachUp(f: BaseType => Unit): Unit = { + children.foreach(_.foreach(f)) + f(this) + } + /** * Returns a Seq containing the result of applying the given function to each * node in this tree in a preorder traversal. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index d0f547d187ecb..eee00e3f7ea76 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -61,6 +61,7 @@ case class OptionalData( case class ComplexData( arrayField: Seq[Int], arrayField1: Array[Int], + arrayField2: List[Int], arrayFieldContainsNull: Seq[java.lang.Integer], mapField: Map[Int, Long], mapFieldValueContainsNull: Map[Int, java.lang.Long], @@ -137,6 +138,10 @@ class ScalaReflectionSuite extends FunSuite { "arrayField1", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField( + "arrayField2", + ArrayType(IntegerType, containsNull = false), + nullable = true), StructField( "arrayFieldContainsNull", ArrayType(IntegerType, containsNull = true), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index f011a5ff15ea9..e70c651e1486e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Literal, Alias, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -108,24 +108,56 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { testRelation) } - test("throw errors for unresolved attributes during analysis") { - val e = intercept[AnalysisException] { - caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation)) + def errorTest( + name: String, + plan: LogicalPlan, + errorMessages: Seq[String], + caseSensitive: Boolean = true) = { + test(name) { + val error = intercept[AnalysisException] { + if(caseSensitive) { + caseSensitiveAnalyze(plan) + } else { + caseInsensitiveAnalyze(plan) + } + } + + errorMessages.foreach(m => assert(error.getMessage contains m)) } - assert(e.getMessage().toLowerCase.contains("cannot resolve")) } - test("throw errors for unresolved plans during analysis") { - case class UnresolvedTestPlan() extends LeafNode { - override lazy val resolved = false - override def output = Nil - } - val e = intercept[AnalysisException] { - caseSensitiveAnalyze(UnresolvedTestPlan()) - } - assert(e.getMessage().toLowerCase.contains("unresolved")) + errorTest( + "unresolved attributes", + testRelation.select('abcd), + "cannot resolve" :: "abcd" :: Nil) + + errorTest( + "bad casts", + testRelation.select(Literal(1).cast(BinaryType).as('badCast)), + "invalid cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) + + errorTest( + "non-boolean filters", + testRelation.where(Literal(1)), + "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil) + + errorTest( + "missing group by", + testRelation2.groupBy('a)('b), + "'b'" :: "group by" :: Nil + ) + + case class UnresolvedTestPlan() extends LeafNode { + override lazy val resolved = false + override def output = Nil } + errorTest( + "catch all unresolved plan", + UnresolvedTestPlan(), + "unresolved" :: Nil) + + test("divide should be casted into fractional types") { val testRelation2 = LocalRelation( AttributeReference("a", StringType)(), @@ -134,18 +166,15 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) - val expr0 = 'a / 2 - val expr1 = 'a / 'b - val expr2 = 'a / 'c - val expr3 = 'a / 'd - val expr4 = 'e / 'e - val plan = caseInsensitiveAnalyze(Project( - Alias(expr0, s"Analyzer($expr0)")() :: - Alias(expr1, s"Analyzer($expr1)")() :: - Alias(expr2, s"Analyzer($expr2)")() :: - Alias(expr3, s"Analyzer($expr3)")() :: - Alias(expr4, s"Analyzer($expr4)")() :: Nil, testRelation2)) + val plan = caseInsensitiveAnalyze( + testRelation2.select( + 'a / Literal(2) as 'div1, + 'a / 'b as 'div2, + 'a / 'c as 'div3, + 'a / 'd as 'div4, + 'e / 'e as 'div5)) val pl = plan.asInstanceOf[Project].projectList + assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 264a0eff37d34..72f06e26e05f1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -30,7 +30,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("AnalysisNodes", Once, - EliminateAnalysisOperators) :: + EliminateSubQueries) :: Batch("Constant Folding", FixedPoint(50), NullPropagation, ConstantFolding, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index e22c62505860a..ef10c0aece716 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateAnalysisOperators} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest @@ -33,7 +33,7 @@ class ConstantFoldingSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("AnalysisNodes", Once, - EliminateAnalysisOperators) :: + EliminateSubQueries) :: Batch("ConstantFolding", Once, ConstantFolding, BooleanSimplification) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 1158b5dfc6147..55c6766520a1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators -import org.apache.spark.sql.catalyst.expressions.Explode +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.expressions.{Count, Explode} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ @@ -32,12 +32,13 @@ class FilterPushdownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, - EliminateAnalysisOperators) :: + EliminateSubQueries) :: Batch("Filter Pushdown", Once, CombineFilters, PushPredicateThroughProject, PushPredicateThroughJoin, - PushPredicateThroughGenerate) :: Nil + PushPredicateThroughGenerate, + ColumnPruning) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -58,6 +59,38 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("column pruning for group") { + val originalQuery = + testRelation + .groupBy('a)('a, Count('b)) + .select('a) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a) + .select('a).analyze + + comparePlans(optimized, correctAnswer) + } + + test("column pruning for group with alias") { + val originalQuery = + testRelation + .groupBy('a)('a as 'c, Count('b)) + .select('c) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a as 'c) + .select('c).analyze + + comparePlans(optimized, correctAnswer) + } + // After this line is unimplemented. test("simple push down") { val originalQuery = @@ -351,7 +384,7 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize(originalQuery.analyze) - comparePlans(analysis.EliminateAnalysisOperators(originalQuery.analyze), optimized) + comparePlans(analysis.EliminateSubQueries(originalQuery.analyze), optimized) } test("joins: conjunctive predicates") { @@ -370,7 +403,7 @@ class FilterPushdownSuite extends PlanTest { left.join(right, condition = Some("x.b".attr === "y.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) } test("joins: conjunctive predicates #2") { @@ -389,7 +422,7 @@ class FilterPushdownSuite extends PlanTest { left.join(right, condition = Some("x.b".attr === "y.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) } test("joins: conjunctive predicates #3") { @@ -412,7 +445,7 @@ class FilterPushdownSuite extends PlanTest { condition = Some("z.a".attr === "x.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) } val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index da912ab382179..233e329cb2038 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest @@ -34,7 +34,7 @@ class OptimizeInSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("AnalysisNodes", Once, - EliminateAnalysisOperators) :: + EliminateSubQueries) :: Batch("ConstantFolding", Once, ConstantFolding, BooleanSimplification, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala index dfef87bd9133d..a54751dfa9a12 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ @@ -29,7 +29,7 @@ class UnionPushdownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, - EliminateAnalysisOperators) :: + EliminateSubQueries) :: Batch("Union Pushdown", Once, UnionPushdown) :: Nil } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/SaveMode.java b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java similarity index 97% rename from sql/core/src/main/java/org/apache/spark/sql/sources/SaveMode.java rename to sql/core/src/main/java/org/apache/spark/sql/SaveMode.java index 3109f5716da2c..a40be526d0d11 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/SaveMode.java +++ b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.sources; +package org.apache.spark.sql; /** * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 9d5d6e78bd487..f6ecee1af8aad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql -import scala.annotation.tailrec import scala.language.implicitConversions -import org.apache.spark.sql.Dsl.lit +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Subquery, Project, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField import org.apache.spark.sql.types._ @@ -127,7 +126,7 @@ trait Column extends DataFrame { * df.select( -df("amount") ) * * // Java: - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * df.select( negate(col("amount") ); * }}} */ @@ -140,7 +139,7 @@ trait Column extends DataFrame { * df.filter( !df("isActive") ) * * // Java: - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * df.filter( not(df.col("isActive")) ); * }} */ @@ -153,7 +152,7 @@ trait Column extends DataFrame { * df.filter( df("colA") === df("colB") ) * * // Java - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * df.filter( col("colA").equalTo(col("colB")) ); * }}} */ @@ -168,7 +167,7 @@ trait Column extends DataFrame { * df.filter( df("colA") === df("colB") ) * * // Java - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * df.filter( col("colA").equalTo(col("colB")) ); * }}} */ @@ -182,7 +181,7 @@ trait Column extends DataFrame { * df.select( !(df("colA") === df("colB")) ) * * // Java: - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * df.filter( col("colA").notEqual(col("colB")) ); * }}} */ @@ -198,7 +197,7 @@ trait Column extends DataFrame { * df.select( !(df("colA") === df("colB")) ) * * // Java: - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * df.filter( col("colA").notEqual(col("colB")) ); * }}} */ @@ -213,7 +212,7 @@ trait Column extends DataFrame { * people.select( people("age") > 21 ) * * // Java: - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * people.select( people("age").gt(21) ); * }}} */ @@ -228,7 +227,7 @@ trait Column extends DataFrame { * people.select( people("age") > lit(21) ) * * // Java: - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * people.select( people("age").gt(21) ); * }}} */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 13aff760e9a5c..e21e989f36c65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import scala.collection.JavaConversions._ import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import org.apache.spark.annotation.{DeveloperApi, Experimental} @@ -26,7 +27,6 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.sources.SaveMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -48,7 +48,7 @@ private[sql] object DataFrame { * }}} * * Once created, it can be manipulated using the various domain-specific-language (DSL) functions - * defined in: [[DataFrame]] (this class), [[Column]], [[Dsl]] for the DSL. + * defined in: [[DataFrame]] (this class), [[Column]], [[functions]] for the DSL. * * To select a column from the data frame, use the apply method: * {{{ @@ -94,27 +94,27 @@ trait DataFrame extends RDDApi[Row] with Serializable { } /** Left here for backward compatibility. */ - @deprecated("1.3.0", "use toDataFrame") + @deprecated("1.3.0", "use toDF") def toSchemaRDD: DataFrame = this /** * Returns the object itself. Used to force an implicit conversion from RDD to DataFrame in Scala. */ // This is declared with parentheses to prevent the Scala compiler from treating - // `rdd.toDataFrame("1")` as invoking this toDataFrame and then apply on the returned DataFrame. - def toDataFrame(): DataFrame = this + // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. + def toDF(): DataFrame = this /** * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion * from a RDD of tuples into a [[DataFrame]] with meaningful names. For example: * {{{ * val rdd: RDD[(Int, String)] = ... - * rdd.toDataFrame // this implicit conversion creates a DataFrame with column name _1 and _2 - * rdd.toDataFrame("id", "name") // this creates a DataFrame with column name "id" and "name" + * rdd.toDF // this implicit conversion creates a DataFrame with column name _1 and _2 + * rdd.toDF("id", "name") // this creates a DataFrame with column name "id" and "name" * }}} */ @scala.annotation.varargs - def toDataFrame(colNames: String*): DataFrame + def toDF(colNames: String*): DataFrame /** Returns the schema of this [[DataFrame]]. */ def schema: StructType @@ -132,7 +132,7 @@ trait DataFrame extends RDDApi[Row] with Serializable { def explain(extended: Boolean): Unit /** Only prints the physical plan to the console for debugging purpose. */ - def explain(): Unit = explain(false) + def explain(): Unit = explain(extended = false) /** * Returns true if the `collect` and `take` methods can be run locally @@ -179,11 +179,11 @@ trait DataFrame extends RDDApi[Row] with Serializable { * * {{{ * // Scala: - * import org.apache.spark.sql.dsl._ + * import org.apache.spark.sql.functions._ * df1.join(df2, "outer", $"df1Key" === $"df2Key") * * // Java: - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * df1.join(df2, "outer", col("df1Key") === col("df2Key")); * }}} * @@ -441,17 +441,54 @@ trait DataFrame extends RDDApi[Row] with Serializable { sample(withReplacement, fraction, Utils.random.nextLong) } + /** + * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more + * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of + * the input row are implicitly joined with each row that is output by the function. + * + * The following example uses this function to count the number of books which contain + * a given word: + * + * {{{ + * case class Book(title: String, words: String) + * val df: RDD[Book] + * + * case class Word(word: String) + * val allWords = df.explode('words) { + * case Row(words: String) => words.split(" ").map(Word(_)) + * } + * + * val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title")) + * }}} + */ + def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame + + + /** + * (Scala-specific) Returns a new [[DataFrame]] where a single column has been expanded to zero + * or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All + * columns of the input row are implicitly joined with each value that is output by the function. + * + * {{{ + * df.explode("words", "word")(words: String => words.split(" ")) + * }}} + */ + def explode[A, B : TypeTag]( + inputColumn: String, + outputColumn: String)( + f: A => TraversableOnce[B]): DataFrame + ///////////////////////////////////////////////////////////////////////////// /** * Returns a new [[DataFrame]] by adding a column. */ - def addColumn(colName: String, col: Column): DataFrame + def withColumn(colName: String, col: Column): DataFrame /** * Returns a new [[DataFrame]] with a column renamed. */ - def renameColumn(existingName: String, newName: String): DataFrame + def withColumnRenamed(existingName: String, newName: String): DataFrame /** * Returns the first `n` rows. @@ -483,6 +520,7 @@ trait DataFrame extends RDDApi[Row] with Serializable { * Returns a new RDD by applying a function to each partition of this DataFrame. */ override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] + /** * Applies a function `f` to all rows. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala new file mode 100644 index 0000000000000..a3187fe3230fd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala @@ -0,0 +1,30 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +/** + * A container for a [[DataFrame]], used for implicit conversions. + */ +private[sql] case class DataFrameHolder(df: DataFrame) { + + // This is declared with parentheses to prevent the Scala compiler from treating + // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. + def toDF(): DataFrame = df + + def toDF(colNames: String*): DataFrame = df.toDF(colNames :_*) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index 4c6e19cace8ca..7b7efbe3477b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -21,6 +21,7 @@ import java.io.CharArrayWriter import scala.language.implicitConversions import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag import scala.collection.JavaConversions._ import com.fasterxml.jackson.core.JsonFactory @@ -29,7 +30,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} +import org.apache.spark.sql.catalyst.{expressions, SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} @@ -39,7 +40,6 @@ import org.apache.spark.sql.json.JsonRDD import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{NumericType, StructType} - /** * Internal implementation of [[DataFrame]]. Users of the API should use [[DataFrame]] directly. */ @@ -94,7 +94,7 @@ private[sql] class DataFrameImpl protected[sql]( } } - override def toDataFrame(colNames: String*): DataFrame = { + override def toDF(colNames: String*): DataFrame = { require(schema.size == colNames.size, "The number of columns doesn't match.\n" + "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" + @@ -229,11 +229,11 @@ private[sql] class DataFrameImpl protected[sql]( }: _*) } - override def addColumn(colName: String, col: Column): DataFrame = { + override def withColumn(colName: String, col: Column): DataFrame = { select(Column("*"), col.as(colName)) } - override def renameColumn(existingName: String, newName: String): DataFrame = { + override def withColumnRenamed(existingName: String, newName: String): DataFrame = { val colNames = schema.map { field => val name = field.name if (name == existingName) Column(name).as(newName) else Column(name) @@ -282,6 +282,32 @@ private[sql] class DataFrameImpl protected[sql]( Sample(fraction, withReplacement, seed, logicalPlan) } + override def explode[A <: Product : TypeTag] + (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { + val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + val attributes = schema.toAttributes + val rowFunction = + f.andThen(_.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row])) + val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr)) + + Generate(generator, join = true, outer = false, None, logicalPlan) + } + + override def explode[A, B : TypeTag]( + inputColumn: String, + outputColumn: String)( + f: A => TraversableOnce[B]): DataFrame = { + val dataType = ScalaReflection.schemaFor[B].dataType + val attributes = AttributeReference(outputColumn, dataType)() :: Nil + def rowFunction(row: Row) = { + f(row(0).asInstanceOf[A]).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType))) + } + val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil) + + Generate(generator, join = true, outer = false, None, logicalPlan) + + } + ///////////////////////////////////////////////////////////////////////////// // RDD API ///////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 3c20676355c9d..0868013fe7c96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql import scala.language.implicitConversions import scala.collection.JavaConversions._ +import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr} import org.apache.spark.sql.catalyst.plans.logical.Aggregate @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Aggregate */ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expression]) { - private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = { + private[this] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { val namedGroupingExprs = groupingExprs.map { case expr: NamedExpression => expr case expr: Expression => Alias(expr, expr.toString)() @@ -52,7 +52,12 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio case "max" => Max case "min" => Min case "sum" => Sum - case "count" | "size" => Count + case "count" | "size" => + // Turn count(*) into count(1) + (inputExpr: Expression) => inputExpr match { + case s: Star => Count(Literal(1)) + case _ => Count(inputExpr) + } } } @@ -115,17 +120,17 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio * Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this * class, the resulting [[DataFrame]] won't automatically include the grouping columns. * - * The available aggregate methods are defined in [[org.apache.spark.sql.Dsl]]. + * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. * * {{{ * // Selects the age of the oldest employee and the aggregate expense for each department * * // Scala: - * import org.apache.spark.sql.dsl._ + * import org.apache.spark.sql.functions._ * df.groupBy("department").agg($"department", max($"age"), sum($"expense")) * * // Java: - * import static org.apache.spark.sql.Dsl.*; + * import static org.apache.spark.sql.functions.*; * df.groupBy("department").agg(col("department"), max(col("age")), sum(col("expense"))); * }}} */ @@ -142,7 +147,7 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio * Count the number of rows for each group. * The resulting [[DataFrame]] will also contain the grouping columns. */ - def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")()) + def count(): DataFrame = Seq(Alias(Count(Literal(1)), "count")()) /** * Compute the average value for each numeric columns for each group. This is an alias for `avg`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala index 4f9d92d97646f..fc37cfa7a899f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD @@ -25,7 +26,6 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedSt import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.storage.StorageLevel -import org.apache.spark.sql.sources.SaveMode import org.apache.spark.sql.types.StructType private[sql] class IncomputableColumn(protected[sql] val expr: Expression) extends Column { @@ -50,7 +50,7 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten protected[sql] override def logicalPlan: LogicalPlan = err() - override def toDataFrame(colNames: String*): DataFrame = err() + override def toDF(colNames: String*): DataFrame = err() override def schema: StructType = err() @@ -86,9 +86,9 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten override def selectExpr(exprs: String*): DataFrame = err() - override def addColumn(colName: String, col: Column): DataFrame = err() + override def withColumn(colName: String, col: Column): DataFrame = err() - override def renameColumn(existingName: String, newName: String): DataFrame = err() + override def withColumnRenamed(existingName: String, newName: String): DataFrame = err() override def filter(condition: Column): DataFrame = err() @@ -110,6 +110,14 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = err() + override def explode[A <: Product : TypeTag] + (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = err() + + override def explode[A, B : TypeTag]( + inputColumn: String, + outputColumn: String)( + f: A => TraversableOnce[B]): DataFrame = err() + ///////////////////////////////////////////////////////////////////////////// override def head(n: Int): Array[Row] = err() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 8aae222acd927..b42a52ebd2f16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -183,14 +183,25 @@ class SQLContext(@transient val sparkContext: SparkContext) object implicits extends Serializable { // scalastyle:on + /** Converts $"col name" into an [[Column]]. */ + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args :_*)) + } + } + + /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */ + implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) + /** Creates a DataFrame from an RDD of case classes or tuples. */ - implicit def rddToDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { - self.createDataFrame(rdd) + implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { + DataFrameHolder(self.createDataFrame(rdd)) } /** Creates a DataFrame from a local Seq of Product. */ - implicit def localSeqToDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { - self.createDataFrame(data) + implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = + { + DataFrameHolder(self.createDataFrame(data)) } // Do NOT add more implicit conversions. They are likely to break source compatibility by @@ -198,7 +209,7 @@ class SQLContext(@transient val sparkContext: SparkContext) // because of [[DoubleRDDFunctions]]. /** Creates a single column DataFrame from an RDD[Int]. */ - implicit def intRddToDataFrame(data: RDD[Int]): DataFrame = { + implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { val dataType = IntegerType val rows = data.mapPartitions { iter => val row = new SpecificMutableRow(dataType :: Nil) @@ -207,11 +218,11 @@ class SQLContext(@transient val sparkContext: SparkContext) row: Row } } - self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)) + DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) } /** Creates a single column DataFrame from an RDD[Long]. */ - implicit def longRddToDataFrame(data: RDD[Long]): DataFrame = { + implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { val dataType = LongType val rows = data.mapPartitions { iter => val row = new SpecificMutableRow(dataType :: Nil) @@ -220,11 +231,11 @@ class SQLContext(@transient val sparkContext: SparkContext) row: Row } } - self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)) + DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) } /** Creates a single column DataFrame from an RDD[String]. */ - implicit def stringRddToDataFrame(data: RDD[String]): DataFrame = { + implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { val dataType = StringType val rows = data.mapPartitions { iter => val row = new SpecificMutableRow(dataType :: Nil) @@ -233,7 +244,7 @@ class SQLContext(@transient val sparkContext: SparkContext) row: Row } } - self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)) + DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) } } @@ -275,6 +286,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * Example: * {{{ * import org.apache.spark.sql._ + * import org.apache.spark.sql.types._ * val sqlContext = new org.apache.spark.sql.SQLContext(sc) * * val schema = @@ -366,6 +378,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * Example: * {{{ * import org.apache.spark.sql._ + * import org.apache.spark.sql.types._ * val sqlContext = new org.apache.spark.sql.SQLContext(sc) * * val schema = @@ -433,7 +446,7 @@ class SQLContext(@transient val sparkContext: SparkContext) baseRelationToDataFrame(parquet.ParquetRelation2(path +: paths, Map.empty)(this)) } else { DataFrame(this, parquet.ParquetRelation( - paths.mkString(","), Some(sparkContext.hadoopConfiguration), this)) + (path +: paths).mkString(","), Some(sparkContext.hadoopConfiguration), this)) } /** @@ -774,6 +787,42 @@ class SQLContext(@transient val sparkContext: SparkContext) def table(tableName: String): DataFrame = DataFrame(this, catalog.lookupRelation(Seq(tableName))) + /** + * Returns a [[DataFrame]] containing names of existing tables in the current database. + * The returned DataFrame has two columns, tableName and isTemporary (a Boolean + * indicating if a table is a temporary one or not). + */ + def tables(): DataFrame = { + createDataFrame(catalog.getTables(None)).toDF("tableName", "isTemporary") + } + + /** + * Returns a [[DataFrame]] containing names of existing tables in the given database. + * The returned DataFrame has two columns, tableName and isTemporary (a Boolean + * indicating if a table is a temporary one or not). + */ + def tables(databaseName: String): DataFrame = { + createDataFrame(catalog.getTables(Some(databaseName))).toDF("tableName", "isTemporary") + } + + /** + * Returns the names of tables in the current database as an array. + */ + def tableNames(): Array[String] = { + catalog.getTables(None).map { + case (tableName, _) => tableName + }.toArray + } + + /** + * Returns the names of tables in the given database as an array. + */ + def tableNames(databaseName: String): Array[String] = { + catalog.getTables(Some(databaseName)).map { + case (tableName, _) => tableName + }.toArray + } + protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index c60d4070942a9..ee94a5fdbe376 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType /** - * A user-defined function. To create one, use the `udf` functions in [[Dsl]]. + * A user-defined function. To create one, use the `udf` functions in [[functions]]. * As an example: * {{{ * // Defined a UDF that returns true or false based on some numeric score. @@ -45,7 +45,7 @@ case class UserDefinedFunction(f: AnyRef, dataType: DataType) { } /** - * A user-defined Python function. To create one, use the `pythonUDF` functions in [[Dsl]]. + * A user-defined Python function. To create one, use the `pythonUDF` functions in [[functions]]. * This is used by Python API. */ private[sql] case class UserDefinedPythonFunction( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala rename to sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 7bc7683576b71..4a0ec0b72ce81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -21,6 +21,7 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -28,17 +29,9 @@ import org.apache.spark.sql.types._ /** * Domain specific functions available for [[DataFrame]]. */ -object Dsl { - - /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */ - implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) - - /** Converts $"col name" into an [[Column]]. */ - implicit class StringToColumn(val sc: StringContext) extends AnyVal { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args :_*)) - } - } +// scalastyle:off +object functions { +// scalastyle:on private[this] implicit def toColumn(expr: Expression): Column = Column(expr) @@ -104,7 +97,11 @@ object Dsl { def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName)) /** Aggregate function: returns the number of items in a group. */ - def count(e: Column): Column = Count(e.expr) + def count(e: Column): Column = e.expr match { + // Turn count(*) into count(1) + case s: Star => Count(Literal(1)) + case _ => Count(e.expr) + } /** Aggregate function: returns the number of items in a group. */ def count(columnName: String): Column = count(Column(columnName)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 51ff2443f3717..24848634de9cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -21,7 +21,7 @@ import java.io.IOException import org.apache.hadoop.fs.Path -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 1043eefcfc6a5..3b8dde1823370 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.json -import java.io.StringWriter -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} -import com.fasterxml.jackson.core.{JsonGenerator, JsonProcessingException, JsonFactory} +import com.fasterxml.jackson.core.{JsonGenerator, JsonProcessingException} import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.rdd.RDD @@ -178,7 +177,12 @@ private[sql] object JsonRDD extends Logging { } private def typeOfPrimitiveValue: PartialFunction[Any, DataType] = { - ScalaReflection.typeOfObject orElse { + // For Integer values, use LongType by default. + val useLongType: PartialFunction[Any, DataType] = { + case value: IntegerType.JvmType => LongType + } + + useLongType orElse ScalaReflection.typeOfObject orElse { // Since we do not have a data type backed by BigInteger, // when we see a Java BigInteger, we use DecimalType. case value: java.math.BigInteger => DecimalType.Unlimited @@ -302,6 +306,10 @@ private[sql] object JsonRDD extends Logging { val parsed = mapper.readValue(record, classOf[Object]) match { case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] + case _ => + sys.error( + s"Failed to parse record $record. Please make sure that each line of the file " + + "(or each string in the RDD) is a valid JSON object or an array of JSON objects.") } parsed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 28cd17fde46ab..7dd8bea49b8a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -647,6 +647,6 @@ private[parquet] object FileSystemHelper { sys.error("ERROR: attempting to append to set of Parquet files and found file" + s"that does not match name pattern: $other") case _ => 0 - }.reduceLeft((a, b) => if (a < b) b else a) + }.reduceOption(_ max _).getOrElse(0) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 8d3e094e3344d..d0856df8d4f43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -23,8 +23,8 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import scala.util.Try -import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} import org.apache.spark.util.Utils /** @@ -37,7 +37,8 @@ import org.apache.spark.util.Utils trait ParquetTest { val sqlContext: SQLContext - import sqlContext._ + import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder} + import sqlContext.{conf, sparkContext} protected def configuration = sparkContext.hadoopConfiguration @@ -49,11 +50,11 @@ trait ParquetTest { */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(getConf(key)).toOption) - (keys, values).zipped.foreach(setConf) + val currentValues = keys.map(key => Try(conf.getConf(key)).toOption) + (keys, values).zipped.foreach(conf.setConf) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => setConf(key, value) + case (key, Some(value)) => conf.setConf(key, value) case (key, None) => conf.unsetConf(key) } } @@ -88,9 +89,8 @@ trait ParquetTest { protected def withParquetFile[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: String => Unit): Unit = { - import sqlContext.implicits._ withTempPath { file => - sparkContext.parallelize(data).saveAsParquetFile(file.getCanonicalPath) + sparkContext.parallelize(data).toDF().saveAsParquetFile(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -102,14 +102,14 @@ trait ParquetTest { protected def withParquetRDD[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => f(parquetFile(path))) + withParquetFile(data)(path => f(sqlContext.parquetFile(path))) } /** * Drops temporary table `tableName` after calling `f`. */ protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally dropTempTable(tableName) + try f finally sqlContext.dropTempTable(tableName) } /** @@ -125,4 +125,26 @@ trait ParquetTest { withTempTable(tableName)(f) } } + + protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + data.toDF().save(path.getCanonicalPath, "org.apache.spark.sql.parquet", SaveMode.Overwrite) + } + + protected def makePartitionDir( + basePath: File, + defaultPartitionName: String, + partitionCols: (String, Any)*): File = { + val partNames = partitionCols.map { case (k, v) => + val valueString = if (v == null || v == "") defaultPartitionName else v.toString + s"$k=$valueString" + } + + val partDir = partNames.foldLeft(basePath) { (parent, child) => + new File(parent, child) + } + + assert(partDir.mkdirs(), s"Couldn't create directory $partDir") + partDir + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index aef9c10fbcd01..9279f5a903f55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -20,7 +20,7 @@ import java.io.IOException import java.lang.{Double => JDouble, Float => JFloat, Long => JLong} import java.math.{BigDecimal => JBigDecimal} import java.text.SimpleDateFormat -import java.util.{List => JList, Date} +import java.util.{Date, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer @@ -34,8 +34,9 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext} import parquet.filter2.predicate.FilterApi import parquet.format.converter.ParquetMetadataConverter -import parquet.hadoop.{ParquetInputFormat, _} +import parquet.hadoop.metadata.CompressionCodecName import parquet.hadoop.util.ContextUtil +import parquet.hadoop.{ParquetInputFormat, _} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil @@ -45,21 +46,35 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.parquet.ParquetTypesConverter._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{IntegerType, StructField, StructType, _} -import org.apache.spark.sql.types.StructType._ -import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext} -import org.apache.spark.{Partition => SparkPartition, TaskContext, SerializableWritable, Logging, SparkException} - +import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode} +import org.apache.spark.{Logging, Partition => SparkPartition, SerializableWritable, SparkException, TaskContext} /** - * Allows creation of parquet based tables using the syntax - * `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet`. Currently the only option - * required is `path`, which should be the location of a collection of, optionally partitioned, - * parquet files. + * Allows creation of Parquet based tables using the syntax: + * {{{ + * CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet OPTIONS (...) + * }}} + * + * Supported options include: + * + * - `path`: Required. When reading Parquet files, `path` should point to the location of the + * Parquet file(s). It can be either a single raw Parquet file, or a directory of Parquet files. + * In the latter case, this data source tries to discover partitioning information if the the + * directory is structured in the same style of Hive partitioned tables. When writing Parquet + * file, `path` should point to the destination folder. + * + * - `mergeSchema`: Optional. Indicates whether we should merge potentially different (but + * compatible) schemas stored in all Parquet part-files. + * + * - `partition.defaultName`: Optional. Partition name used when a value of a partition column is + * null or empty string. This is similar to the `hive.exec.default.partition.name` configuration + * in Hive. */ class DefaultSource extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider { + private def checkPath(parameters: Map[String, String]): String = { parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables.")) } @@ -71,6 +86,7 @@ class DefaultSource ParquetRelation2(Seq(checkPath(parameters)), parameters, None)(sqlContext) } + /** Returns a new base relation with the given parameters and schema. */ override def createRelation( sqlContext: SQLContext, parameters: Map[String, String], @@ -78,6 +94,7 @@ class DefaultSource ParquetRelation2(Seq(checkPath(parameters)), parameters, Some(schema))(sqlContext) } + /** Returns a new base relation with the given parameters and save given data into it. */ override def createRelation( sqlContext: SQLContext, mode: SaveMode, @@ -86,33 +103,19 @@ class DefaultSource val path = checkPath(parameters) val filesystemPath = new Path(path) val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val doSave = if (fs.exists(filesystemPath)) { - mode match { - case SaveMode.Append => - sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}") - case SaveMode.Overwrite => - fs.delete(filesystemPath, true) - true - case SaveMode.ErrorIfExists => - sys.error(s"path $path already exists.") - case SaveMode.Ignore => false - } - } else { - true + val doInsertion = (mode, fs.exists(filesystemPath)) match { + case (SaveMode.ErrorIfExists, true) => + sys.error(s"path $path already exists.") + case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => + true + case (SaveMode.Ignore, exists) => + !exists } - val relation = if (doSave) { - // Only save data when the save mode is not ignore. - ParquetRelation.createEmpty( - path, - data.schema.toAttributes, - false, - sqlContext.sparkContext.hadoopConfiguration, - sqlContext) - - val createdRelation = createRelation(sqlContext, parameters, data.schema) - createdRelation.asInstanceOf[ParquetRelation2].insert(data, true) - + val relation = if (doInsertion) { + val createdRelation = + createRelation(sqlContext, parameters, data.schema).asInstanceOf[ParquetRelation2] + createdRelation.insert(data, overwrite = mode == SaveMode.Overwrite) createdRelation } else { // If the save mode is Ignore, we will just create the relation based on existing data. @@ -123,37 +126,31 @@ class DefaultSource } } -private[parquet] case class Partition(values: Row, path: String) +private[sql] case class Partition(values: Row, path: String) -private[parquet] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) +private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) /** * An alternative to [[ParquetRelation]] that plugs in using the data sources API. This class is - * currently not intended as a full replacement of the parquet support in Spark SQL though it is - * likely that it will eventually subsume the existing physical plan implementation. - * - * Compared with the current implementation, this class has the following notable differences: - * - * Partitioning: Partitions are auto discovered and must be in the form of directories `key=value/` - * located at `path`. Currently only a single partitioning column is supported and it must - * be an integer. This class supports both fully self-describing data, which contains the partition - * key, and data where the partition key is only present in the folder structure. The presence - * of the partitioning key in the data is also auto-detected. The `null` partition is not yet - * supported. + * intended as a full replacement of the Parquet support in Spark SQL. The old implementation will + * be deprecated and eventually removed once this version is proved to be stable enough. * - * Metadata: The metadata is automatically discovered by reading the first parquet file present. - * There is currently no support for working with files that have different schema. Additionally, - * when parquet metadata caching is turned on, the FileStatus objects for all data will be cached - * to improve the speed of interactive querying. When data is added to a table it must be dropped - * and recreated to pick up any changes. + * Compared with the old implementation, this class has the following notable differences: * - * Statistics: Statistics for the size of the table are automatically populated during metadata - * discovery. + * - Partitioning discovery: Hive style multi-level partitions are auto discovered. + * - Metadata discovery: Parquet is a format comes with schema evolving support. This data source + * can detect and merge schemas from all Parquet part-files as long as they are compatible. + * Also, metadata and [[FileStatus]]es are cached for better performance. + * - Statistics: Statistics for the size of the table are automatically populated during schema + * discovery. */ @DeveloperApi -case class ParquetRelation2 - (paths: Seq[String], parameters: Map[String, String], maybeSchema: Option[StructType] = None) - (@transient val sqlContext: SQLContext) +case class ParquetRelation2( + paths: Seq[String], + parameters: Map[String, String], + maybeSchema: Option[StructType] = None, + maybePartitionSpec: Option[PartitionSpec] = None)( + @transient val sqlContext: SQLContext) extends CatalystScan with InsertableRelation with SparkHadoopMapReduceUtil @@ -176,43 +173,90 @@ case class ParquetRelation2 override def equals(other: Any) = other match { case relation: ParquetRelation2 => + // If schema merging is required, we don't compare the actual schemas since they may evolve. + val schemaEquality = if (shouldMergeSchemas) { + shouldMergeSchemas == relation.shouldMergeSchemas + } else { + schema == relation.schema + } + paths.toSet == relation.paths.toSet && + schemaEquality && maybeMetastoreSchema == relation.maybeMetastoreSchema && - (shouldMergeSchemas == relation.shouldMergeSchemas || schema == relation.schema) + maybePartitionSpec == relation.maybePartitionSpec + case _ => false } private[sql] def sparkContext = sqlContext.sparkContext - @transient private val fs = FileSystem.get(sparkContext.hadoopConfiguration) - private class MetadataCache { + // `FileStatus` objects of all "_metadata" files. private var metadataStatuses: Array[FileStatus] = _ + + // `FileStatus` objects of all "_common_metadata" files. private var commonMetadataStatuses: Array[FileStatus] = _ + + // Parquet footer cache. private var footers: Map[FileStatus, Footer] = _ - private var parquetSchema: StructType = _ + // `FileStatus` objects of all data files (Parquet part-files). var dataStatuses: Array[FileStatus] = _ + + // Partition spec of this table, including names, data types, and values of each partition + // column, and paths of each partition. var partitionSpec: PartitionSpec = _ + + // Schema of the actual Parquet files, without partition columns discovered from partition + // directory paths. + var parquetSchema: StructType = _ + + // Schema of the whole table, including partition columns. var schema: StructType = _ - var dataSchemaIncludesPartitionKeys: Boolean = _ + // Indicates whether partition columns are also included in Parquet data file schema. If not, + // we need to fill in partition column values into read rows when scanning the table. + var partitionKeysIncludedInParquetSchema: Boolean = _ + + def prepareMetadata(path: Path, schema: StructType, conf: Configuration): Unit = { + conf.set( + ParquetOutputFormat.COMPRESSION, + ParquetRelation + .shortParquetCompressionCodecNames + .getOrElse( + sqlContext.conf.parquetCompressionCodec.toUpperCase, + CompressionCodecName.UNCOMPRESSED).name()) + + ParquetRelation.enableLogForwarding() + ParquetTypesConverter.writeMetaData(schema.toAttributes, path, conf) + } + + /** + * Refreshes `FileStatus`es, footers, partition spec, and table schema. + */ def refresh(): Unit = { - val baseStatuses = { - val statuses = paths.distinct.map(p => fs.getFileStatus(fs.makeQualified(new Path(p)))) - // Support either reading a collection of raw Parquet part-files, or a collection of folders - // containing Parquet files (e.g. partitioned Parquet table). - assert(statuses.forall(!_.isDir) || statuses.forall(_.isDir)) - statuses.toArray - } + val fs = FileSystem.get(sparkContext.hadoopConfiguration) + + // Support either reading a collection of raw Parquet part-files, or a collection of folders + // containing Parquet files (e.g. partitioned Parquet table). + val baseStatuses = paths.distinct.map { p => + val qualified = fs.makeQualified(new Path(p)) + + if (!fs.exists(qualified) && maybeSchema.isDefined) { + fs.mkdirs(qualified) + prepareMetadata(qualified, maybeSchema.get, sparkContext.hadoopConfiguration) + } + + fs.getFileStatus(qualified) + }.toArray + assert(baseStatuses.forall(!_.isDir) || baseStatuses.forall(_.isDir)) + // Lists `FileStatus`es of all leaf nodes (files) under all base directories. val leaves = baseStatuses.flatMap { f => - val statuses = SparkHadoopUtil.get.listLeafStatuses(fs, f.getPath).filter { f => + SparkHadoopUtil.get.listLeafStatuses(fs, f.getPath).filter { f => isSummaryFile(f.getPath) || !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) } - assert(statuses.nonEmpty, s"${f.getPath} is an empty folder.") - statuses } dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) @@ -226,13 +270,14 @@ case class ParquetRelation2 f -> new Footer(f.getPath, parquetMetadata) }.seq.toMap - partitionSpec = { - val partitionDirs = dataStatuses + partitionSpec = maybePartitionSpec.getOrElse { + val partitionDirs = leaves .filterNot(baseStatuses.contains) .map(_.getPath.getParent) .distinct if (partitionDirs.nonEmpty) { + // Parses names and values of partition columns, and infer their data types. ParquetRelation2.parsePartitions(partitionDirs, defaultPartitionName) } else { // No partition directories found, makes an empty specification @@ -242,20 +287,22 @@ case class ParquetRelation2 parquetSchema = maybeSchema.getOrElse(readSchema()) - dataSchemaIncludesPartitionKeys = + partitionKeysIncludedInParquetSchema = isPartitioned && - partitionColumns.forall(f => metadataCache.parquetSchema.fieldNames.contains(f.name)) + partitionColumns.forall(f => parquetSchema.fieldNames.contains(f.name)) schema = { - val fullParquetSchema = if (dataSchemaIncludesPartitionKeys) { - metadataCache.parquetSchema + val fullRelationSchema = if (partitionKeysIncludedInParquetSchema) { + parquetSchema } else { - StructType(metadataCache.parquetSchema.fields ++ partitionColumns.fields) + StructType(parquetSchema.fields ++ partitionColumns.fields) } + // If this Parquet relation is converted from a Hive Metastore table, must reconcile case + // insensitivity issue and possible schema mismatch. maybeMetastoreSchema - .map(ParquetRelation2.mergeMetastoreParquetSchema(_, fullParquetSchema)) - .getOrElse(fullParquetSchema) + .map(ParquetRelation2.mergeMetastoreParquetSchema(_, fullRelationSchema)) + .getOrElse(fullRelationSchema) } } @@ -304,13 +351,17 @@ case class ParquetRelation2 @transient private val metadataCache = new MetadataCache metadataCache.refresh() - private def partitionColumns = metadataCache.partitionSpec.partitionColumns + def partitionSpec = metadataCache.partitionSpec - private def partitions = metadataCache.partitionSpec.partitions + def partitionColumns = metadataCache.partitionSpec.partitionColumns - private def isPartitioned = partitionColumns.nonEmpty + def partitions = metadataCache.partitionSpec.partitions - private def dataSchemaIncludesPartitionKeys = metadataCache.dataSchemaIncludesPartitionKeys + def isPartitioned = partitionColumns.nonEmpty + + private def partitionKeysIncludedInDataSchema = metadataCache.partitionKeysIncludedInParquetSchema + + private def parquetSchema = metadataCache.parquetSchema override def schema = metadataCache.schema @@ -413,18 +464,21 @@ case class ParquetRelation2 // When the data does not include the key and the key is requested then we must fill it in // based on information from the input split. - if (!dataSchemaIncludesPartitionKeys && partitionKeyLocations.nonEmpty) { + if (!partitionKeysIncludedInDataSchema && partitionKeyLocations.nonEmpty) { baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) => val partValues = selectedPartitions.collectFirst { case p if split.getPath.getParent.toString == p.path => p.values }.get + val requiredPartOrdinal = partitionKeyLocations.keys.toSeq + iterator.map { pair => val row = pair._2.asInstanceOf[SpecificMutableRow] var i = 0 - while (i < partValues.size) { + while (i < requiredPartOrdinal.size) { // TODO Avoids boxing cost here! - row.update(partitionKeyLocations(i), partValues(i)) + val partOrdinal = requiredPartOrdinal(i) + row.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal)) i += 1 } row @@ -458,6 +512,8 @@ case class ParquetRelation2 } override def insert(data: DataFrame, overwrite: Boolean): Unit = { + assert(paths.size == 1, s"Can't write to multiple destinations: ${paths.mkString(",")}") + // TODO: currently we do not check whether the "schema"s are compatible // That means if one first creates a table and then INSERTs data with // and incompatible schema the execution will fail. It would be nice @@ -465,7 +521,7 @@ case class ParquetRelation2 // before calling execute(). val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val writeSupport = if (schema.map(_.dataType).forall(_.isPrimitive)) { + val writeSupport = if (parquetSchema.map(_.dataType).forall(_.isPrimitive)) { log.debug("Initializing MutableRowWriteSupport") classOf[MutableRowWriteSupport] } else { @@ -475,7 +531,7 @@ case class ParquetRelation2 ParquetOutputFormat.setWriteSupportClass(job, writeSupport) val conf = ContextUtil.getConfiguration(job) - RowWriteSupport.setSchema(schema.toAttributes, conf) + RowWriteSupport.setSchema(data.schema.toAttributes, conf) val destinationPath = new Path(paths.head) @@ -545,14 +601,12 @@ object ParquetRelation2 { // Whether we should merge schemas collected from all Parquet part-files. val MERGE_SCHEMA = "mergeSchema" - // Hive Metastore schema, passed in when the Parquet relation is converted from Metastore - val METASTORE_SCHEMA = "metastoreSchema" - - // Default partition name to use when the partition column value is null or empty string + // Default partition name to use when the partition column value is null or empty string. val DEFAULT_PARTITION_NAME = "partition.defaultName" - // When true, the Parquet data source caches Parquet metadata for performance - val CACHE_METADATA = "cacheMetadata" + // Hive Metastore schema, used when converting Metastore Parquet tables. This option is only used + // internally. + private[sql] val METASTORE_SCHEMA = "metastoreSchema" private[parquet] def readSchema(footers: Seq[Footer], sqlContext: SQLContext): StructType = { footers.map { footer => @@ -580,6 +634,15 @@ object ParquetRelation2 { } } + /** + * Reconciles Hive Metastore case insensitivity issue and data type conflicts between Metastore + * schema and Parquet schema. + * + * Hive doesn't retain case information, while Parquet is case sensitive. On the other hand, the + * schema read from Parquet files may be incomplete (e.g. older versions of Parquet doesn't + * distinguish binary and string). This method generates a correct schema by merging Metastore + * schema data types and Parquet schema field names. + */ private[parquet] def mergeMetastoreParquetSchema( metastoreSchema: StructType, parquetSchema: StructType): StructType = { @@ -720,16 +783,15 @@ object ParquetRelation2 { * }}} */ private[parquet] def resolvePartitions(values: Seq[PartitionValues]): Seq[PartitionValues] = { - val distinctColNamesOfPartitions = values.map(_.columnNames).distinct - val columnCount = values.head.columnNames.size - // Column names of all partitions must match - assert(distinctColNamesOfPartitions.size == 1, { - val list = distinctColNamesOfPartitions.mkString("\t", "\n", "") + val distinctPartitionsColNames = values.map(_.columnNames).distinct + assert(distinctPartitionsColNames.size == 1, { + val list = distinctPartitionsColNames.mkString("\t", "\n", "") s"Conflicting partition column names detected:\n$list" }) // Resolves possible type conflicts for each column + val columnCount = values.head.columnNames.size val resolvedValues = (0 until columnCount).map { i => resolveTypeConflicts(values.map(_.literals(i))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index d3d72089c3303..8cac9c0fdf7fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.sources import scala.language.implicitConversions import org.apache.spark.Logging -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 37fda7ba6e5d0..0c4b706eeebae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.sources import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{SaveMode, DataFrame, Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java index 2e6e977fdc752..643b891ab1b63 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -164,7 +164,7 @@ public void applySchemaToJSON() { fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(), true)); fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true)); - fields.add(DataTypes.createStructField("integer", DataTypes.IntegerType, true)); + fields.add(DataTypes.createStructField("integer", DataTypes.LongType, true)); fields.add(DataTypes.createStructField("long", DataTypes.LongType, true)); fields.add(DataTypes.createStructField("null", DataTypes.StringType, true)); fields.add(DataTypes.createStructField("string", DataTypes.StringType, true)); diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java index 639436368c4a3..05233dc5ffc58 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java @@ -23,7 +23,7 @@ import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.types.DataTypes; -import static org.apache.spark.sql.Dsl.*; +import static org.apache.spark.sql.functions.*; /** * This test doesn't actually run anything. It is here to check the API compatibility for Java. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1318750a4a3b0..691dae0a0561b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -25,8 +25,9 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} case class BigData(s: String) @@ -34,8 +35,6 @@ case class BigData(s: String) class CachedTableSuite extends QueryTest { TestData // Load test tables. - import org.apache.spark.sql.test.TestSQLContext.implicits._ - def rddIdOf(tableName: String): Int = { val executedPlan = table(tableName).queryExecution.executedPlan executedPlan.collect { @@ -95,7 +94,7 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 10000 - sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData") + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF().registerTempTable("bigData") table("bigData").persist(StorageLevel.MEMORY_AND_DISK) assert(table("bigData").count() === 200000L) table("bigData").unpersist(blocking = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index e3e6f652ed3ed..a63d733ece627 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType} @@ -68,7 +68,7 @@ class ColumnExpressionSuite extends QueryTest { } test("collect on column produced by a binary operator") { - val df = Seq((1, 2, 3)).toDataFrame("a", "b", "c") + val df = Seq((1, 2, 3)).toDF("a", "b", "c") checkAnswer(df("a") + df("b"), Seq(Row(3))) checkAnswer(df("a") + df("b").as("c"), Seq(Row(3))) } @@ -79,7 +79,7 @@ class ColumnExpressionSuite extends QueryTest { test("star qualified by data frame object") { // This is not yet supported. - val df = testData.toDataFrame + val df = testData.toDF val goldAnswer = df.collect().toSeq checkAnswer(df.select(df("*")), goldAnswer) @@ -156,13 +156,13 @@ class ColumnExpressionSuite extends QueryTest { test("isNull") { checkAnswer( - nullStrings.toDataFrame.where($"s".isNull), + nullStrings.toDF.where($"s".isNull), nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) } test("isNotNull") { checkAnswer( - nullStrings.toDataFrame.where($"s".isNotNull), + nullStrings.toDF.where($"s".isNotNull), nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index 8fa830dd9390f..2d2367d6e7292 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -25,31 +25,31 @@ class DataFrameImplicitsSuite extends QueryTest { test("RDD of tuples") { checkAnswer( - sc.parallelize(1 to 10).map(i => (i, i.toString)).toDataFrame("intCol", "strCol"), + sc.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), (1 to 10).map(i => Row(i, i.toString))) } test("Seq of tuples") { checkAnswer( - (1 to 10).map(i => (i, i.toString)).toDataFrame("intCol", "strCol"), + (1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), (1 to 10).map(i => Row(i, i.toString))) } test("RDD[Int]") { checkAnswer( - sc.parallelize(1 to 10).toDataFrame("intCol"), + sc.parallelize(1 to 10).toDF("intCol"), (1 to 10).map(i => Row(i))) } test("RDD[Long]") { checkAnswer( - sc.parallelize(1L to 10L).toDataFrame("longCol"), + sc.parallelize(1L to 10L).toDF("longCol"), (1L to 10L).map(i => Row(i))) } test("RDD[String]") { checkAnswer( - sc.parallelize(1 to 10).map(_.toString).toDataFrame("stringCol"), + sc.parallelize(1 to 10).map(_.toString).toDF("stringCol"), (1 to 10).map(i => Row(i.toString))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 7be9215a443f0..f0cd43632ec3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.TestData._ import scala.language.postfixOps -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery @@ -98,6 +98,31 @@ class DataFrameSuite extends QueryTest { sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) } + test("simple explode") { + val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words") + + checkAnswer( + df.explode("words", "word") { word: String => word.split(" ").toSeq }.select('word), + Row("a") :: Row("b") :: Row("c") :: Row("d") ::Row("e") :: Nil + ) + } + + test("explode") { + val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters") + val df2 = + df.explode('letters) { + case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq + } + + checkAnswer( + df2 + .select('_1 as 'letter, 'number) + .groupBy('letter) + .agg('letter, countDistinct('number)), + Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil + ) + } + test("selectExpr") { checkAnswer( testData.selectExpr("abs(key)", "value"), @@ -116,15 +141,30 @@ class DataFrameSuite extends QueryTest { testData.select('key).collect().toSeq) } - test("agg") { + test("groupBy") { checkAnswer( testData2.groupBy("a").agg($"a", sum($"b")), - Seq(Row(1,3), Row(2,3), Row(3,3)) + Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) checkAnswer( testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)), Row(9) ) + checkAnswer( + testData2.groupBy("a").agg(col("a"), count("*")), + Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil + ) + checkAnswer( + testData2.groupBy("a").agg(Map("*" -> "count")), + Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil + ) + checkAnswer( + testData2.groupBy("a").agg(Map("b" -> "sum")), + Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil + ) + } + + test("agg without groups") { checkAnswer( testData2.agg(sum('b)), Row(9) @@ -193,20 +233,20 @@ class DataFrameSuite extends QueryTest { Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) checkAnswer( - arrayData.orderBy('data.getItem(0).asc), - arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) + arrayData.toDF.orderBy('data.getItem(0).asc), + arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) checkAnswer( - arrayData.orderBy('data.getItem(0).desc), - arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) + arrayData.toDF.orderBy('data.getItem(0).desc), + arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) checkAnswer( - arrayData.orderBy('data.getItem(1).asc), - arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) + arrayData.toDF.orderBy('data.getItem(1).asc), + arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) checkAnswer( - arrayData.orderBy('data.getItem(1).desc), - arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) + arrayData.toDF.orderBy('data.getItem(1).desc), + arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) } test("limit") { @@ -215,11 +255,11 @@ class DataFrameSuite extends QueryTest { testData.take(10).toSeq) checkAnswer( - arrayData.limit(1), + arrayData.toDF.limit(1), arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) checkAnswer( - mapData.limit(1), + mapData.toDF.limit(1), mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) } @@ -353,7 +393,7 @@ class DataFrameSuite extends QueryTest { } test("addColumn") { - val df = testData.toDataFrame.addColumn("newCol", col("key") + 1) + val df = testData.toDF.withColumn("newCol", col("key") + 1) checkAnswer( df, testData.collect().map { case Row(key: Int, value: String) => @@ -363,8 +403,8 @@ class DataFrameSuite extends QueryTest { } test("renameColumn") { - val df = testData.toDataFrame.addColumn("newCol", col("key") + 1) - .renameColumn("value", "valueRenamed") + val df = testData.toDF.withColumn("newCol", col("key") + 1) + .withColumnRenamed("value", "valueRenamed") checkAnswer( df, testData.collect().map { case Row(key: Int, value: String) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index f0c939dbb195f..fd73065c4ada3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ class JoinSuite extends QueryTest with BeforeAndAfterEach { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala new file mode 100644 index 0000000000000..282b98a987dd4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -0,0 +1,76 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} + +class ListTablesSuite extends QueryTest with BeforeAndAfter { + + import org.apache.spark.sql.test.TestSQLContext.implicits._ + + val df = + sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value") + + before { + df.registerTempTable("ListTablesSuiteTable") + } + + after { + catalog.unregisterTable(Seq("ListTablesSuiteTable")) + } + + test("get all tables") { + checkAnswer( + tables().filter("tableName = 'ListTablesSuiteTable'"), + Row("ListTablesSuiteTable", true)) + + catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + } + + test("getting all Tables with a database name has no impact on returned table names") { + checkAnswer( + tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + Row("ListTablesSuiteTable", true)) + + catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + } + + test("query the returned DataFrame of tables") { + val tableDF = tables() + val schema = StructType( + StructField("tableName", StringType, true) :: + StructField("isTemporary", BooleanType, false) :: Nil) + assert(schema === tableDF.schema) + + tableDF.registerTempTable("tables") + checkAnswer( + sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), + Row(true, "ListTablesSuiteTable") + ) + checkAnswer( + tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + Row("tables", true)) + dropTempTable("tables") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bba8899651259..97684f75e79fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.test.TestSQLContext import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ @@ -806,10 +806,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("throw errors for non-aggregate attributes with aggregation") { def checkAggregation(query: String, isInvalidQuery: Boolean = true) { if (isInvalidQuery) { - val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed) - assert( - e.getMessage.startsWith("Expression not in GROUP BY"), - "Non-aggregate attribute(s) not detected\n") + val e = intercept[AnalysisException](sql(query).queryExecution.analyzed) + assert(e.getMessage contains "group by") } else { // Should not throw sql(query).queryExecution.analyzed @@ -1036,10 +1034,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("Supporting relational operator '<=>' in Spark SQL") { val nullCheckData1 = TestData(1,"1") :: TestData(2,null) :: Nil val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) - rdd1.registerTempTable("nulldata1") + rdd1.toDF.registerTempTable("nulldata1") val nullCheckData2 = TestData(1,"1") :: TestData(2,null) :: Nil val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) - rdd2.registerTempTable("nulldata2") + rdd2.toDF.registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), (1 to 2).map(i => Row(i))) @@ -1048,7 +1046,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("Multi-column COUNT(DISTINCT ...)") { val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) - rdd.registerTempTable("distinctData") + rdd.toDF.registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 93782619826f0..9a48f8d0634cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -82,7 +82,7 @@ class ScalaReflectionRelationSuite extends FunSuite { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerTempTable("reflectData") + rdd.toDF.registerTempTable("reflectData") assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, @@ -93,7 +93,7 @@ class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD with nulls") { val data = NullReflectData(null, null, null, null, null, null, null) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerTempTable("reflectNullData") + rdd.toDF.registerTempTable("reflectNullData") assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -101,7 +101,7 @@ class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD with Nones") { val data = OptionalReflectData(None, None, None, None, None, None, None) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerTempTable("reflectOptionalData") + rdd.toDF.registerTempTable("reflectOptionalData") assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -109,7 +109,7 @@ class ScalaReflectionRelationSuite extends FunSuite { // Equality is broken for Arrays, so we test that separately. test("query binary data") { val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) - rdd.registerTempTable("reflectBinary") + rdd.toDF.registerTempTable("reflectBinary") val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) @@ -128,7 +128,7 @@ class ScalaReflectionRelationSuite extends FunSuite { Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None), Nested(None, "abc"))) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerTempTable("reflectComplexData") + rdd.toDF.registerTempTable("reflectComplexData") assert(sql("SELECT * FROM reflectComplexData").collect().head === new GenericRow(Array[Any]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 0ed437edd05fd..c511eb1469167 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import java.sql.Timestamp import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test._ import org.apache.spark.sql.test.TestSQLContext.implicits._ @@ -29,11 +29,11 @@ case class TestData(key: Int, value: String) object TestData { val testData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toDataFrame + (1 to 100).map(i => TestData(i, i.toString))).toDF testData.registerTempTable("testData") val negativeData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(-i, (-i).toString))).toDataFrame + (1 to 100).map(i => TestData(-i, (-i).toString))).toDF negativeData.registerTempTable("negativeData") case class LargeAndSmallInts(a: Int, b: Int) @@ -44,7 +44,7 @@ object TestData { LargeAndSmallInts(2147483645, 1) :: LargeAndSmallInts(2, 2) :: LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil).toDataFrame + LargeAndSmallInts(3, 2) :: Nil).toDF largeAndSmallInts.registerTempTable("largeAndSmallInts") case class TestData2(a: Int, b: Int) @@ -55,7 +55,7 @@ object TestData { TestData2(2, 1) :: TestData2(2, 2) :: TestData2(3, 1) :: - TestData2(3, 2) :: Nil, 2).toDataFrame + TestData2(3, 2) :: Nil, 2).toDF testData2.registerTempTable("testData2") case class DecimalData(a: BigDecimal, b: BigDecimal) @@ -67,7 +67,7 @@ object TestData { DecimalData(2, 1) :: DecimalData(2, 2) :: DecimalData(3, 1) :: - DecimalData(3, 2) :: Nil).toDataFrame + DecimalData(3, 2) :: Nil).toDF decimalData.registerTempTable("decimalData") case class BinaryData(a: Array[Byte], b: Int) @@ -77,14 +77,14 @@ object TestData { BinaryData("22".getBytes(), 5) :: BinaryData("122".getBytes(), 3) :: BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil).toDataFrame + BinaryData("123".getBytes(), 4) :: Nil).toDF binaryData.registerTempTable("binaryData") case class TestData3(a: Int, b: Option[Int]) val testData3 = TestSQLContext.sparkContext.parallelize( TestData3(1, None) :: - TestData3(2, Some(2)) :: Nil).toDataFrame + TestData3(2, Some(2)) :: Nil).toDF testData3.registerTempTable("testData3") val emptyTableData = logical.LocalRelation($"a".int, $"b".int) @@ -97,7 +97,7 @@ object TestData { UpperCaseData(3, "C") :: UpperCaseData(4, "D") :: UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil).toDataFrame + UpperCaseData(6, "F") :: Nil).toDF upperCaseData.registerTempTable("upperCaseData") case class LowerCaseData(n: Int, l: String) @@ -106,7 +106,7 @@ object TestData { LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil).toDataFrame + LowerCaseData(4, "d") :: Nil).toDF lowerCaseData.registerTempTable("lowerCaseData") case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) @@ -114,7 +114,7 @@ object TestData { TestSQLContext.sparkContext.parallelize( ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) :: ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil) - arrayData.registerTempTable("arrayData") + arrayData.toDF.registerTempTable("arrayData") case class MapData(data: scala.collection.Map[Int, String]) val mapData = @@ -124,18 +124,18 @@ object TestData { MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: MapData(Map(1 -> "a4", 2 -> "b4")) :: MapData(Map(1 -> "a5")) :: Nil) - mapData.registerTempTable("mapData") + mapData.toDF.registerTempTable("mapData") case class StringData(s: String) val repeatedData = TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) - repeatedData.registerTempTable("repeatedData") + repeatedData.toDF.registerTempTable("repeatedData") val nullableRepeatedData = TestSQLContext.sparkContext.parallelize( List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) - nullableRepeatedData.registerTempTable("nullableRepeatedData") + nullableRepeatedData.toDF.registerTempTable("nullableRepeatedData") case class NullInts(a: Integer) val nullInts = @@ -144,7 +144,7 @@ object TestData { NullInts(2) :: NullInts(3) :: NullInts(null) :: Nil - ) + ).toDF nullInts.registerTempTable("nullInts") val allNulls = @@ -152,7 +152,7 @@ object TestData { NullInts(null) :: NullInts(null) :: NullInts(null) :: - NullInts(null) :: Nil) + NullInts(null) :: Nil).toDF allNulls.registerTempTable("allNulls") case class NullStrings(n: Int, s: String) @@ -160,11 +160,11 @@ object TestData { TestSQLContext.sparkContext.parallelize( NullStrings(1, "abc") :: NullStrings(2, "ABC") :: - NullStrings(3, null) :: Nil).toDataFrame + NullStrings(3, null) :: Nil).toDF nullStrings.registerTempTable("nullStrings") case class TableName(tableName: String) - TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerTempTable("tableName") + TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).toDF.registerTempTable("tableName") val unparsedStrings = TestSQLContext.sparkContext.parallelize( @@ -177,22 +177,22 @@ object TestData { val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i => TimestampField(new Timestamp(i)) }) - timestamps.registerTempTable("timestamps") + timestamps.toDF.registerTempTable("timestamps") case class IntField(i: Int) // An RDD with 4 elements and 8 partitions val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) - withEmptyParts.registerTempTable("withEmptyParts") + withEmptyParts.toDF.registerTempTable("withEmptyParts") case class Person(id: Int, name: String, age: Int) case class Salary(personId: Int, salary: Double) val person = TestSQLContext.sparkContext.parallelize( Person(0, "mike", 30) :: - Person(1, "jim", 20) :: Nil) + Person(1, "jim", 20) :: Nil).toDF person.registerTempTable("person") val salary = TestSQLContext.sparkContext.parallelize( Salary(0, 2000.0) :: - Salary(1, 1000.0) :: Nil) + Salary(1, 1000.0) :: Nil).toDF salary.registerTempTable("salary") case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean) @@ -200,6 +200,6 @@ object TestData { TestSQLContext.sparkContext.parallelize( ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true) :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false) - :: Nil).toDataFrame + :: Nil).toDF complexData.registerTempTable("complexData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 95923f9aad931..be105c6e83594 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql -import org.apache.spark.sql.Dsl.StringToColumn import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ +import TestSQLContext.implicits._ case class FunctionResult(f1: String, f2: String) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 3c1657cd5fc3a..5f21d990e2e5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql} import org.apache.spark.sql.test.TestSQLContext.implicits._ @@ -66,7 +66,7 @@ class UserDefinedTypeSuite extends QueryTest { val points = Seq( MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) - val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) + val pointsRDD = sparkContext.parallelize(points).toDF() test("register user type: MyDenseVector for MyLabeledPoint") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 86b1b5fda1c0f..38b0f666ab90b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.columnar -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY @@ -28,8 +29,6 @@ class InMemoryColumnarQuerySuite extends QueryTest { // Make sure the tables are loaded. TestData - import org.apache.spark.sql.test.TestSQLContext.implicits._ - test("simple columnar query") { val plan = executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) @@ -39,7 +38,8 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("default size avoids broadcast") { // TODO: Improve this test when we have better statistics - sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).registerTempTable("sizeTst") + sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + .toDF().registerTempTable("sizeTst") cacheTable("sizeTst") assert( table("sizeTst").queryExecution.logical.statistics.sizeInBytes > diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 55a9f735b3506..e57bb06e7263b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -21,13 +21,12 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.apache.spark.sql._ import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { val originalColumnBatchSize = conf.columnBatchSize val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning - import org.apache.spark.sql.test.TestSQLContext.implicits._ - override protected def beforeAll(): Unit = { // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch setConf(SQLConf.COLUMN_BATCH_SIZE, "10") @@ -35,7 +34,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be val pruningData = sparkContext.makeRDD((1 to 100).map { key => val string = if (((key - 1) / 10) % 2 == 0) null else key.toString TestData(key, string) - }, 5) + }, 5).toDF() pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index c3210733f1d42..523be56df65ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.execution import org.scalatest.FunSuite import org.apache.spark.sql.{SQLConf, execution} -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test.TestSQLContext.planner._ import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index fde4b47438c56..c94e44bd7c397 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -21,11 +21,12 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType} import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} @@ -222,7 +223,7 @@ class JsonSuite extends QueryTest { StructField("bigInteger", DecimalType.Unlimited, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: - StructField("integer", IntegerType, true) :: + StructField("integer", LongType, true) :: StructField("long", LongType, true) :: StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) @@ -252,7 +253,7 @@ class JsonSuite extends QueryTest { StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, false), true) :: StructField("arrayOfBoolean", ArrayType(BooleanType, false), true) :: StructField("arrayOfDouble", ArrayType(DoubleType, false), true) :: - StructField("arrayOfInteger", ArrayType(IntegerType, false), true) :: + StructField("arrayOfInteger", ArrayType(LongType, false), true) :: StructField("arrayOfLong", ArrayType(LongType, false), true) :: StructField("arrayOfNull", ArrayType(StringType, true), true) :: StructField("arrayOfString", ArrayType(StringType, false), true) :: @@ -265,7 +266,7 @@ class JsonSuite extends QueryTest { StructField("field1", BooleanType, true) :: StructField("field2", DecimalType.Unlimited, true) :: Nil), true) :: StructField("structWithArrayFields", StructType( - StructField("field1", ArrayType(IntegerType, false), true) :: + StructField("field1", ArrayType(LongType, false), true) :: StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil) assert(expectedSchema === jsonDF.schema) @@ -486,7 +487,7 @@ class JsonSuite extends QueryTest { val jsonDF = jsonRDD(complexFieldValueTypeConflict) val expectedSchema = StructType( - StructField("array", ArrayType(IntegerType, false), true) :: + StructField("array", ArrayType(LongType, false), true) :: StructField("num_struct", StringType, true) :: StructField("str_array", StringType, true) :: StructField("struct", StructType( @@ -540,7 +541,7 @@ class JsonSuite extends QueryTest { val expectedSchema = StructType( StructField("a", BooleanType, true) :: StructField("b", LongType, true) :: - StructField("c", ArrayType(IntegerType, false), true) :: + StructField("c", ArrayType(LongType, false), true) :: StructField("d", StructType( StructField("field", BooleanType, true) :: Nil), true) :: StructField("e", StringType, true) :: Nil) @@ -560,7 +561,7 @@ class JsonSuite extends QueryTest { StructField("bigInteger", DecimalType.Unlimited, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: - StructField("integer", IntegerType, true) :: + StructField("integer", LongType, true) :: StructField("long", LongType, true) :: StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) @@ -781,12 +782,12 @@ class JsonSuite extends QueryTest { ArrayType(ArrayType(ArrayType(ArrayType(StringType, false), false), true), false), true) :: StructField("field2", ArrayType(ArrayType( - StructType(StructField("Test", IntegerType, true) :: Nil), false), true), true) :: + StructType(StructField("Test", LongType, true) :: Nil), false), true), true) :: StructField("field3", ArrayType(ArrayType( StructType(StructField("Test", StringType, true) :: Nil), true), false), true) :: StructField("field4", - ArrayType(ArrayType(ArrayType(IntegerType, false), true), false), true) :: Nil) + ArrayType(ArrayType(ArrayType(LongType, false), true), false), true) :: Nil) assert(schema === jsonDF.schema) @@ -822,7 +823,7 @@ class JsonSuite extends QueryTest { val df1 = createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") - val df2 = df1.toDataFrame + val df2 = df1.toDF val result = df2.toJSON.collect() assert(result(0) === "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}") assert(result(3) === "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}") @@ -843,7 +844,7 @@ class JsonSuite extends QueryTest { val df3 = createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") - val df4 = df3.toDataFrame + val df4 = df3.toDF val result2 = df4.toJSON.collect() assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index f8117c21773ae..eb2d5f25290b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.parquet +import org.scalatest.BeforeAndAfterAll import parquet.filter2.predicate.Operators._ import parquet.filter2.predicate.{FilterPredicate, Operators} @@ -40,7 +41,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred * data type is nullable. */ -class ParquetFilterSuite extends QueryTest with ParquetTest { +class ParquetFilterSuiteBase extends QueryTest with ParquetTest { val sqlContext = TestSQLContext private def checkFilterPredicate( @@ -112,210 +113,224 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } - def run(prefix: String): Unit = { - test(s"$prefix: filter pushdown - boolean") { - withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) + test("filter pushdown - boolean") { + withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) - checkFilterPredicate('_1 === true, classOf[Eq[_]], true) - checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) - } + checkFilterPredicate('_1 === true, classOf[Eq[_]], true) + checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) } + } - test(s"$prefix: filter pushdown - short") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit rdd => - checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq[_]], 1) - checkFilterPredicate( - Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt[_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt[_]], 4) - checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) - - checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) - checkFilterPredicate( - Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, classOf[Operators.And], 3) - checkFilterPredicate( - Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, - classOf[Operators.Or], - Seq(Row(1), Row(4))) - } + test("filter pushdown - short") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit rdd => + checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq[_]], 1) + checkFilterPredicate( + Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt[_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt[_]], 4) + checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) + + checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) + checkFilterPredicate( + Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, classOf[Operators.And], 3) + checkFilterPredicate( + Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, + classOf[Operators.Or], + Seq(Row(1), Row(4))) } + } - test(s"$prefix: filter pushdown - integer") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + test("filter pushdown - integer") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) - } + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } + } - test(s"$prefix: filter pushdown - long") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + test("filter pushdown - long") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) - } + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } + } - test(s"$prefix: filter pushdown - float") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + test("filter pushdown - float") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) - } + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } + } - test(s"$prefix: filter pushdown - double") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + test("filter pushdown - double") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) - } + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } + } - test(s"$prefix: filter pushdown - string") { - withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate( - '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) - - checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") - checkFilterPredicate( - '_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) - - checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") - checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") - checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") - checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") - - checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") - checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") - checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") - checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") - checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") - - checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") - checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") - checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) - } + test("filter pushdown - string") { + withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate( + '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) + + checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") + checkFilterPredicate( + '_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) + + checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") + checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") + checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") + checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") + + checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") + checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") + checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") + checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") + checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") + + checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") + checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") + checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) } + } - test(s"$prefix: filter pushdown - binary") { - implicit class IntToBinary(int: Int) { - def b: Array[Byte] = int.toString.getBytes("UTF-8") - } + test("filter pushdown - binary") { + implicit class IntToBinary(int: Int) { + def b: Array[Byte] = int.toString.getBytes("UTF-8") + } - withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { implicit rdd => - checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) + withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { implicit rdd => + checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) - checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkBinaryFilterPredicate( - '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) + checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkBinaryFilterPredicate( + '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) - checkBinaryFilterPredicate( - '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) + checkBinaryFilterPredicate( + '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) - checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) - checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) - checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) + checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) - checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) - checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) - checkBinaryFilterPredicate( - '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) - } + checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) + checkBinaryFilterPredicate( + '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) } } +} + +class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + +class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { - run("Parquet data source enabled") + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") } - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { - run("Parquet data source disabled") + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index c8ebbbc7d2eac..208f35761b807 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -21,6 +21,9 @@ import scala.collection.JavaConversions._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.scalatest.BeforeAndAfterAll import parquet.example.data.simple.SimpleGroup import parquet.example.data.{Group, GroupWriter} import parquet.hadoop.api.WriteSupport @@ -30,15 +33,13 @@ import parquet.hadoop.{ParquetFileWriter, ParquetWriter} import parquet.io.api.RecordConsumer import parquet.schema.{MessageType, MessageTypeParser} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf} -import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode} // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport // with an empty configuration (it is after all not intended to be used in this way?) @@ -63,9 +64,11 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS /** * A test suite that tests basic Parquet I/O. */ -class ParquetIOSuite extends QueryTest with ParquetTest { +class ParquetIOSuiteBase extends QueryTest with ParquetTest { val sqlContext = TestSQLContext + import sqlContext.implicits.localSeqToDataFrameHolder + /** * Writes `data` to a Parquet file, reads it back and check file contents. */ @@ -73,229 +76,281 @@ class ParquetIOSuite extends QueryTest with ParquetTest { withParquetRDD(data)(r => checkAnswer(r, data.map(Row.fromTuple))) } - def run(prefix: String): Unit = { - test(s"$prefix: basic data types (without binary)") { - val data = (1 to 4).map { i => - (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) - } - checkParquetFile(data) + test("basic data types (without binary)") { + val data = (1 to 4).map { i => + (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) } + checkParquetFile(data) + } - test(s"$prefix: raw binary") { - val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte))) - withParquetRDD(data) { rdd => - assertResult(data.map(_._1.mkString(",")).sorted) { - rdd.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted - } + test("raw binary") { + val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte))) + withParquetRDD(data) { rdd => + assertResult(data.map(_._1.mkString(",")).sorted) { + rdd.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted } } + } - test(s"$prefix: string") { - val data = (1 to 4).map(i => Tuple1(i.toString)) - // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL - // as we store Spark SQL schema in the extra metadata. - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "false")(checkParquetFile(data)) - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "true")(checkParquetFile(data)) - } + test("string") { + val data = (1 to 4).map(i => Tuple1(i.toString)) + // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL + // as we store Spark SQL schema in the extra metadata. + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "false")(checkParquetFile(data)) + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "true")(checkParquetFile(data)) + } - test(s"$prefix: fixed-length decimals") { - import org.apache.spark.sql.test.TestSQLContext.implicits._ - - def makeDecimalRDD(decimal: DecimalType): DataFrame = - sparkContext - .parallelize(0 to 1000) - .map(i => Tuple1(i / 100.0)) - // Parquet doesn't allow column names with spaces, have to add an alias here - .select($"_1" cast decimal as "dec") - - for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { - withTempPath { dir => - val data = makeDecimalRDD(DecimalType(precision, scale)) - data.saveAsParquetFile(dir.getCanonicalPath) - checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) - } + test("fixed-length decimals") { + + def makeDecimalRDD(decimal: DecimalType): DataFrame = + sparkContext + .parallelize(0 to 1000) + .map(i => Tuple1(i / 100.0)) + .toDF + // Parquet doesn't allow column names with spaces, have to add an alias here + .select($"_1" cast decimal as "dec") + + for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { + withTempPath { dir => + val data = makeDecimalRDD(DecimalType(precision, scale)) + data.saveAsParquetFile(dir.getCanonicalPath) + checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) } + } - // Decimals with precision above 18 are not yet supported - intercept[RuntimeException] { - withTempPath { dir => - makeDecimalRDD(DecimalType(19, 10)).saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).collect() - } + // Decimals with precision above 18 are not yet supported + intercept[RuntimeException] { + withTempPath { dir => + makeDecimalRDD(DecimalType(19, 10)).saveAsParquetFile(dir.getCanonicalPath) + parquetFile(dir.getCanonicalPath).collect() } + } - // Unlimited-length decimals are not yet supported - intercept[RuntimeException] { - withTempPath { dir => - makeDecimalRDD(DecimalType.Unlimited).saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).collect() - } + // Unlimited-length decimals are not yet supported + intercept[RuntimeException] { + withTempPath { dir => + makeDecimalRDD(DecimalType.Unlimited).saveAsParquetFile(dir.getCanonicalPath) + parquetFile(dir.getCanonicalPath).collect() } } + } + + test("map") { + val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) + checkParquetFile(data) + } - test(s"$prefix: map") { - val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) - checkParquetFile(data) + test("array") { + val data = (1 to 4).map(i => Tuple1(Seq(i, i + 1))) + checkParquetFile(data) + } + + test("struct") { + val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) + withParquetRDD(data) { rdd => + // Structs are converted to `Row`s + checkAnswer(rdd, data.map { case Tuple1(struct) => + Row(Row(struct.productIterator.toSeq: _*)) + }) } + } - test(s"$prefix: array") { - val data = (1 to 4).map(i => Tuple1(Seq(i, i + 1))) - checkParquetFile(data) + test("nested struct with array of array as field") { + val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) + withParquetRDD(data) { rdd => + // Structs are converted to `Row`s + checkAnswer(rdd, data.map { case Tuple1(struct) => + Row(Row(struct.productIterator.toSeq: _*)) + }) } + } - test(s"$prefix: struct") { - val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) - withParquetRDD(data) { rdd => - // Structs are converted to `Row`s - checkAnswer(rdd, data.map { case Tuple1(struct) => - Row(Row(struct.productIterator.toSeq: _*)) - }) - } + test("nested map with struct as value type") { + val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) + withParquetRDD(data) { rdd => + checkAnswer(rdd, data.map { case Tuple1(m) => + Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) + }) } + } - test(s"$prefix: nested struct with array of array as field") { - val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) - withParquetRDD(data) { rdd => - // Structs are converted to `Row`s - checkAnswer(rdd, data.map { case Tuple1(struct) => - Row(Row(struct.productIterator.toSeq: _*)) - }) - } + test("nulls") { + val allNulls = ( + null.asInstanceOf[java.lang.Boolean], + null.asInstanceOf[Integer], + null.asInstanceOf[java.lang.Long], + null.asInstanceOf[java.lang.Float], + null.asInstanceOf[java.lang.Double]) + + withParquetRDD(allNulls :: Nil) { rdd => + val rows = rdd.collect() + assert(rows.size === 1) + assert(rows.head === Row(Seq.fill(5)(null): _*)) } + } - test(s"$prefix: nested map with struct as value type") { - val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) - withParquetRDD(data) { rdd => - checkAnswer(rdd, data.map { case Tuple1(m) => - Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) - }) - } + test("nones") { + val allNones = ( + None.asInstanceOf[Option[Int]], + None.asInstanceOf[Option[Long]], + None.asInstanceOf[Option[String]]) + + withParquetRDD(allNones :: Nil) { rdd => + val rows = rdd.collect() + assert(rows.size === 1) + assert(rows.head === Row(Seq.fill(3)(null): _*)) } + } - test(s"$prefix: nulls") { - val allNulls = ( - null.asInstanceOf[java.lang.Boolean], - null.asInstanceOf[Integer], - null.asInstanceOf[java.lang.Long], - null.asInstanceOf[java.lang.Float], - null.asInstanceOf[java.lang.Double]) - - withParquetRDD(allNulls :: Nil) { rdd => - val rows = rdd.collect() - assert(rows.size === 1) - assert(rows.head === Row(Seq.fill(5)(null): _*)) - } + test("compression codec") { + def compressionCodecFor(path: String) = { + val codecs = ParquetTypesConverter + .readMetaData(new Path(path), Some(configuration)) + .getBlocks + .flatMap(_.getColumns) + .map(_.getCodec.name()) + .distinct + + assert(codecs.size === 1) + codecs.head } - test(s"$prefix: nones") { - val allNones = ( - None.asInstanceOf[Option[Int]], - None.asInstanceOf[Option[Long]], - None.asInstanceOf[Option[String]]) + val data = (0 until 10).map(i => (i, i.toString)) - withParquetRDD(allNones :: Nil) { rdd => - val rows = rdd.collect() - assert(rows.size === 1) - assert(rows.head === Row(Seq.fill(3)(null): _*)) + def checkCompressionCodec(codec: CompressionCodecName): Unit = { + withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) { + withParquetFile(data) { path => + assertResult(conf.parquetCompressionCodec.toUpperCase) { + compressionCodecFor(path) + } + } } } - test(s"$prefix: compression codec") { - def compressionCodecFor(path: String) = { - val codecs = ParquetTypesConverter - .readMetaData(new Path(path), Some(configuration)) - .getBlocks - .flatMap(_.getColumns) - .map(_.getCodec.name()) - .distinct - - assert(codecs.size === 1) - codecs.head - } + // Checks default compression codec + checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec)) - val data = (0 until 10).map(i => (i, i.toString)) + checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) + checkCompressionCodec(CompressionCodecName.GZIP) + checkCompressionCodec(CompressionCodecName.SNAPPY) + } - def checkCompressionCodec(codec: CompressionCodecName): Unit = { - withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) { - withParquetFile(data) { path => - assertResult(conf.parquetCompressionCodec.toUpperCase) { - compressionCodecFor(path) - } - } - } + test("read raw Parquet file") { + def makeRawParquetFile(path: Path): Unit = { + val schema = MessageTypeParser.parseMessageType( + """ + |message root { + | required boolean _1; + | required int32 _2; + | required int64 _3; + | required float _4; + | required double _5; + |} + """.stripMargin) + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + + (0 until 10).foreach { i => + val record = new SimpleGroup(schema) + record.add(0, i % 2 == 0) + record.add(1, i) + record.add(2, i.toLong) + record.add(3, i.toFloat) + record.add(4, i.toDouble) + writer.write(record) } - // Checks default compression codec - checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec)) + writer.close() + } - checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) - checkCompressionCodec(CompressionCodecName.GZIP) - checkCompressionCodec(CompressionCodecName.SNAPPY) + withTempDir { dir => + val path = new Path(dir.toURI.toString, "part-r-0.parquet") + makeRawParquetFile(path) + checkAnswer(parquetFile(path.toString), (0 until 10).map { i => + Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) + }) } + } - test(s"$prefix: read raw Parquet file") { - def makeRawParquetFile(path: Path): Unit = { - val schema = MessageTypeParser.parseMessageType( - """ - |message root { - | required boolean _1; - | required int32 _2; - | required int64 _3; - | required float _4; - | required double _5; - |} - """.stripMargin) - - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - - (0 until 10).foreach { i => - val record = new SimpleGroup(schema) - record.add(0, i % 2 == 0) - record.add(1, i) - record.add(2, i.toLong) - record.add(3, i.toFloat) - record.add(4, i.toDouble) - writer.write(record) - } + test("write metadata") { + withTempPath { file => + val path = new Path(file.toURI.toString) + val fs = FileSystem.getLocal(configuration) + val attributes = ScalaReflection.attributesFor[(Int, String)] + ParquetTypesConverter.writeMetaData(attributes, path, configuration) - writer.close() - } + assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) + assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - withTempDir { dir => - val path = new Path(dir.toURI.toString, "part-r-0.parquet") - makeRawParquetFile(path) - checkAnswer(parquetFile(path.toString), (0 until 10).map { i => - Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) - }) - } + val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration)) + val actualSchema = metaData.getFileMetaData.getSchema + val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) + + actualSchema.checkContains(expectedSchema) + expectedSchema.checkContains(actualSchema) } + } - test(s"$prefix: write metadata") { - withTempPath { file => - val path = new Path(file.toURI.toString) - val fs = FileSystem.getLocal(configuration) - val attributes = ScalaReflection.attributesFor[(Int, String)] - ParquetTypesConverter.writeMetaData(attributes, path, configuration) + test("save - overwrite") { + withParquetFile((1 to 10).map(i => (i, i.toString))) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Overwrite, Map("path" -> file)) + checkAnswer(parquetFile(file), newData.map(Row.fromTuple)) + } + } - assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) - assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) + test("save - ignore") { + val data = (1 to 10).map(i => (i, i.toString)) + withParquetFile(data) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Ignore, Map("path" -> file)) + checkAnswer(parquetFile(file), data.map(Row.fromTuple)) + } + } - val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration)) - val actualSchema = metaData.getFileMetaData.getSchema - val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) + test("save - throw") { + val data = (1 to 10).map(i => (i, i.toString)) + withParquetFile(data) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + val errorMessage = intercept[Throwable] { + newData.toDF().save( + "org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> file)) + }.getMessage + assert(errorMessage.contains("already exists")) + } + } - actualSchema.checkContains(expectedSchema) - expectedSchema.checkContains(actualSchema) - } + test("save - append") { + val data = (1 to 10).map(i => (i, i.toString)) + withParquetFile(data) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Append, Map("path" -> file)) + checkAnswer(parquetFile(file), (data ++ newData).map(Row.fromTuple)) } } +} + +class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + +class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { - run("Parquet data source enabled") + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") } - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { - run("Parquet data source disabled") + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index ae606d11a8f68..3bf0116c8f7e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -19,17 +19,24 @@ package org.apache.spark.sql.parquet import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.parquet.ParquetRelation2._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{QueryTest, Row, SQLContext} -class ParquetPartitionDiscoverySuite extends FunSuite with ParquetTest { +// The data where the partitioning key exists only in the directory structure. +case class ParquetData(intField: Int, stringField: String) + +// The data that also includes the partitioning key +case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) + +class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { override val sqlContext: SQLContext = TestSQLContext + import sqlContext._ + val defaultPartitionName = "__NULL__" test("column type inference") { @@ -112,6 +119,17 @@ class ParquetPartitionDiscoverySuite extends FunSuite with ParquetTest { Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), Partition(Row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + check(Seq( + s"hdfs://host:9000/path/a=10/b=20", + s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType))), + Seq( + Partition(Row(10, "20"), s"hdfs://host:9000/path/a=10/b=20"), + Partition(Row(null, "hello"), s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) + check(Seq( s"hdfs://host:9000/path/a=10/b=$defaultPartitionName", s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"), @@ -123,4 +141,182 @@ class ParquetPartitionDiscoverySuite extends FunSuite with ParquetTest { Partition(Row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), Partition(Row(10.5, null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) } + + test("read partitioned table - normal case") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeParquetFile( + (1 to 10).map(i => ParquetData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + parquetFile(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, 1, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, i.toString, pi, "foo")) + } + } + } + + test("read partitioned table - partition key included in Parquet file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeParquetFile( + (1 to 10).map(i => ParquetDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + parquetFile(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, 1, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, "foo")) + } + } + } + + test("read partitioned table - with nulls") { + withTempDir { base => + for { + // Must be `Integer` rather than `Int` here. `null.asInstanceOf[Int]` results in a zero... + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeParquetFile( + (1 to 10).map(i => ParquetData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + val parquetRelation = load( + "org.apache.spark.sql.parquet", + Map( + "path" -> base.getCanonicalPath, + ParquetRelation2.DEFAULT_PARTITION_NAME -> defaultPartitionName)) + + parquetRelation.registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi IS NULL"), + for { + i <- 1 to 10 + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, null, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + } yield Row(i, i.toString, pi, null)) + } + } + } + + test("read partitioned table - with nulls and partition keys are included in Parquet file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeParquetFile( + (1 to 10).map(i => ParquetDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + val parquetRelation = load( + "org.apache.spark.sql.parquet", + Map( + "path" -> base.getCanonicalPath, + ParquetRelation2.DEFAULT_PARTITION_NAME -> defaultPartitionName)) + + parquetRelation.registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, null)) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index cba06835f9a61..d0665450cd766 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,103 +17,120 @@ package org.apache.spark.sql.parquet +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.{QueryTest, SQLConf} /** * A test suite that tests various Parquet queries. */ -class ParquetQuerySuite extends QueryTest with ParquetTest { +class ParquetQuerySuiteBase extends QueryTest with ParquetTest { val sqlContext = TestSQLContext - def run(prefix: String): Unit = { - test(s"$prefix: simple projection") { - withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { - checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) - } + test("simple projection") { + withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { + checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) } + } - test(s"$prefix: appending") { - val data = (0 until 10).map(i => (i, i.toString)) - withParquetTable(data, "t") { - sql("INSERT INTO TABLE t SELECT * FROM t") - checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) - } + test("appending") { + val data = (0 until 10).map(i => (i, i.toString)) + withParquetTable(data, "t") { + sql("INSERT INTO TABLE t SELECT * FROM t") + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) } + } - // This test case will trigger the NPE mentioned in - // https://issues.apache.org/jira/browse/PARQUET-151. - ignore(s"$prefix: overwriting") { - val data = (0 until 10).map(i => (i, i.toString)) - withParquetTable(data, "t") { - sql("INSERT OVERWRITE TABLE t SELECT * FROM t") - checkAnswer(table("t"), data.map(Row.fromTuple)) - } + // This test case will trigger the NPE mentioned in + // https://issues.apache.org/jira/browse/PARQUET-151. + // Update: This also triggers SPARK-5746, should re enable it when we get both fixed. + ignore("overwriting") { + val data = (0 until 10).map(i => (i, i.toString)) + withParquetTable(data, "t") { + sql("INSERT OVERWRITE TABLE t SELECT * FROM t") + checkAnswer(table("t"), data.map(Row.fromTuple)) } + } - test(s"$prefix: self-join") { - // 4 rows, cells of column 1 of row 2 and row 4 are null - val data = (1 to 4).map { i => - val maybeInt = if (i % 2 == 0) None else Some(i) - (maybeInt, i.toString) - } - - withParquetTable(data, "t") { - val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") - val queryOutput = selfJoin.queryExecution.analyzed.output + test("self-join") { + // 4 rows, cells of column 1 of row 2 and row 4 are null + val data = (1 to 4).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + (maybeInt, i.toString) + } - assertResult(4, s"Field count mismatches")(queryOutput.size) - assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { - queryOutput.filter(_.name == "_1").map(_.exprId).size - } + withParquetTable(data, "t") { + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") + val queryOutput = selfJoin.queryExecution.analyzed.output - checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) + assertResult(4, "Field count mismatche")(queryOutput.size) + assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + queryOutput.filter(_.name == "_1").map(_.exprId).size } + + checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) } + } - test(s"$prefix: nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { - case Tuple1((_, Seq(string))) => Row(string) - }) - } + test("nested data - struct with array field") { + val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { + case Tuple1((_, Seq(string))) => Row(string) + }) } + } - test(s"$prefix: nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { - case Tuple1(Seq((_, string))) => Row(string) - }) - } + test("nested data - array of struct") { + val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { + case Tuple1(Seq((_, string))) => Row(string) + }) } + } - test(s"$prefix: SPARK-1913 regression: columns only referenced by pushed down filters should remain") { - withParquetTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) - } + test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { + withParquetTable((1 to 10).map(Tuple1.apply), "t") { + checkAnswer(sql("SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) } + } - test(s"$prefix: SPARK-5309 strings stored using dictionary compression in parquet") { - withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { + test("SPARK-5309 strings stored using dictionary compression in parquet") { + withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { - checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), - (0 until 10).map(i => Row("same", "run_" + i, 100))) + checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), + (0 until 10).map(i => Row("same", "run_" + i, 100))) - checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), - List(Row("same", "run_5", 100))) - } + checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), + List(Row("same", "run_5", 100))) } } +} + +class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + +class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { - run("Parquet data source enabled") + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") } - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { - run("Parquet data source disabled") + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index a51004567175c..607488ccfdd6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -22,7 +22,7 @@ import java.io.File import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.{SQLConf, DataFrame} +import org.apache.spark.sql.{SaveMode, SQLConf, DataFrame} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7ae6ed6f841bf..87b380f950979 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -22,26 +22,24 @@ import java.sql.Timestamp import scala.collection.JavaConversions._ import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.Table -import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.parse.VariableSubstitution +import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators, OverrideCatalog, OverrideFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand, QueryExecutionException} -import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DescribeHiveTableCommand} -import org.apache.spark.sql.sources.{CreateTableUsing, DataSourceStrategy} +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, QueryExecutionException, SetCommand} +import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} +import org.apache.spark.sql.sources.DataSourceStrategy import org.apache.spark.sql.types._ /** @@ -104,7 +102,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { */ @Experimental def analyze(tableName: String) { - val relation = EliminateAnalysisOperators(catalog.lookupRelation(Seq(tableName))) + val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) relation match { case relation: MetastoreRelation => @@ -244,6 +242,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override protected[sql] lazy val analyzer = new Analyzer(catalog, functionRegistry, caseSensitive = false) { override val extendedRules = + catalog.ParquetConversions :: catalog.CreateTables :: catalog.PreInsertionCasts :: ExtractPythonUdfs :: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index c78369d12cf55..6d794d0e11391 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -20,25 +20,25 @@ package org.apache.spark.sql.hive import java.io.IOException import java.util.{List => JList} -import com.google.common.cache.{LoadingCache, CacheLoader, CacheBuilder} - -import org.apache.hadoop.util.ReflectionUtils -import org.apache.hadoop.hive.metastore.{Warehouse, TableType} -import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition, FieldSchema} +import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} +import org.apache.hadoop.hive.metastore.api.{FieldSchema, Partition => TPartition, Table => TTable} +import org.apache.hadoop.hive.metastore.{TableType, Warehouse} import org.apache.hadoop.hive.ql.metadata._ import org.apache.hadoop.hive.ql.plan.CreateTableDesc import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException} import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException} +import org.apache.hadoop.util.ReflectionUtils import org.apache.spark.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis.{Catalog, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.parquet.{ParquetRelation2, Partition => ParquetPartition, PartitionSpec} import org.apache.spark.sql.sources.{DDLParser, LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -101,16 +101,10 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val caseSensitive: Boolean = false - /** * - * Creates a data source table (a table created with USING clause) in Hive's metastore. - * Returns true when the table has been created. Otherwise, false. - * @param tableName - * @param userSpecifiedSchema - * @param provider - * @param options - * @param isExternal - * @return - */ + /** + * Creates a data source table (a table created with USING clause) in Hive's metastore. + * Returns true when the table has been created. Otherwise, false. + */ def createDataSourceTable( tableName: String, userSpecifiedSchema: Option[StructType], @@ -141,7 +135,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } def hiveDefaultTableFilePath(tableName: String): String = { - val currentDatabase = client.getDatabase(hive.sessionState.getCurrentDatabase()) + val currentDatabase = client.getDatabase(hive.sessionState.getCurrentDatabase) hiveWarehouse.getTablePath(currentDatabase, tableName).toString } @@ -176,28 +170,49 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Nil } - val relation = MetastoreRelation( - databaseName, tblName, alias)( - table.getTTable, partitions.map(part => part.getTPartition))(hive) - - if (hive.convertMetastoreParquet && - hive.conf.parquetUseDataSourceApi && - relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet")) { - val metastoreSchema = StructType.fromAttributes(relation.output) - val paths = if (relation.hiveQlTable.isPartitioned) { - relation.hiveQlPartitions.map(p => p.getLocation) - } else { - Seq(relation.hiveQlTable.getDataLocation.toString) - } + MetastoreRelation(databaseName, tblName, alias)( + table.getTTable, partitions.map(part => part.getTPartition))(hive) + } + } - LogicalRelation(ParquetRelation2( - paths, Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json))(hive)) - } else { - relation + private def convertToParquetRelation(metastoreRelation: MetastoreRelation): LogicalRelation = { + val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) + + // NOTE: Instead of passing Metastore schema directly to `ParquetRelation2`, we have to + // serialize the Metastore schema to JSON and pass it as a data source option because of the + // evil case insensitivity issue, which is reconciled within `ParquetRelation2`. + if (metastoreRelation.hiveQlTable.isPartitioned) { + val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) + val partitionColumnDataTypes = partitionSchema.map(_.dataType) + val partitions = metastoreRelation.hiveQlPartitions.map { p => + val location = p.getLocation + val values = Row.fromSeq(p.getValues.zip(partitionColumnDataTypes).map { + case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) + }) + ParquetPartition(values, location) } + val partitionSpec = PartitionSpec(partitionSchema, partitions) + val paths = partitions.map(_.path) + LogicalRelation( + ParquetRelation2( + paths, + Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json), + None, + Some(partitionSpec))(hive)) + } else { + val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString) + LogicalRelation( + ParquetRelation2( + paths, + Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json))(hive)) } } + override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { + val dbName = databaseName.getOrElse(hive.sessionState.getCurrentDatabase) + client.getAllTables(dbName).map(tableName => (tableName, false)) + } + /** * Create table with specified database, table name, table description and schema * @param databaseName Database Name @@ -256,9 +271,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with logInfo(s"Default to LazySimpleSerDe for table $dbName.$tblName") tbl.setSerializationLib(classOf[LazySimpleSerDe].getName()) - import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.io.Text + import org.apache.hadoop.mapred.TextInputFormat tbl.setInputFormatClass(classOf[TextInputFormat]) tbl.setOutputFormatClass(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]]) @@ -380,13 +395,56 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } } + /** + * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet + * data source relations for better performance. + * + * This rule can be considered as [[HiveStrategies.ParquetConversion]] done right. + */ + object ParquetConversions extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + // Collects all `MetastoreRelation`s which should be replaced + val toBeReplaced = plan.collect { + // Write path + case InsertIntoTable(relation: MetastoreRelation, _, _, _) + // Inserting into partitioned table is not supported in Parquet data source (yet). + if !relation.hiveQlTable.isPartitioned && + hive.convertMetastoreParquet && + hive.conf.parquetUseDataSourceApi && + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + relation + + // Read path + case p @ PhysicalOperation(_, _, relation: MetastoreRelation) + if hive.convertMetastoreParquet && + hive.conf.parquetUseDataSourceApi && + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + relation + } + + // Replaces all `MetastoreRelation`s with corresponding `ParquetRelation2`s, and fixes + // attribute IDs referenced in other nodes. + toBeReplaced.distinct.foldLeft(plan) { (lastPlan, relation) => + val parquetRelation = convertToParquetRelation(relation) + val attributedRewrites = AttributeMap(relation.output.zip(parquetRelation.output)) + + lastPlan.transformUp { + case r: MetastoreRelation if r == relation => parquetRelation + case other => other.transformExpressions { + case a: Attribute if a.resolved => attributedRewrites.getOrElse(a, a) + } + } + } + } + } + /** * Creates any tables required for query execution. * For example, because of a CREATE TABLE X AS statement. */ object CreateTables extends Rule[LogicalPlan] { import org.apache.hadoop.hive.ql.Context - import org.apache.hadoop.hive.ql.parse.{QB, ASTNode, SemanticAnalyzer} + import org.apache.hadoop.hive.ql.parse.{ASTNode, QB, SemanticAnalyzer} def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Wait until children are resolved. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index f3c9e63652a8e..5269460e5b6bc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1099,7 +1099,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Cast(nodeToExpr(arg), DateType) /* Arithmetic */ - case Token("+", child :: Nil) => Add(Literal(0), nodeToExpr(child)) + case Token("+", child :: Nil) => nodeToExpr(child) case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child)) case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index cb138be90e2e1..965d159656d80 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -139,15 +139,19 @@ private[hive] trait HiveStrategies { val partitionLocations = partitions.map(_.getLocation) - hiveContext - .parquetFile(partitionLocations.head, partitionLocations.tail: _*) - .addPartitioningAttributes(relation.partitionKeys) - .lowerCase - .where(unresolvedOtherPredicates) - .select(unresolvedProjection: _*) - .queryExecution - .executedPlan - .fakeOutput(projectList.map(_.toAttribute)) :: Nil + if (partitionLocations.isEmpty) { + PhysicalRDD(plan.output, sparkContext.emptyRDD[Row]) :: Nil + } else { + hiveContext + .parquetFile(partitionLocations.head, partitionLocations.tail: _*) + .addPartitioningAttributes(relation.partitionKeys) + .lowerCase + .where(unresolvedOtherPredicates) + .select(unresolvedProjection: _*) + .queryExecution + .executedPlan + .fakeOutput(projectList.map(_.toAttribute)) :: Nil + } } else { hiveContext .parquetFile(relation.hiveQlTable.getDataLocation.toString) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index f6bea1c6a6fe1..0aa5f7f7b88bd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.RunnableCommand @@ -175,7 +175,7 @@ case class CreateMetastoreDataSourceAsSelect( val resolved = ResolvedDataSource(sqlContext, Some(query.schema), provider, optionsWithPath) val createdRelation = LogicalRelation(resolved.relation) - EliminateAnalysisOperators(sqlContext.table(tableName).logicalPlan) match { + EliminateSubQueries(sqlContext.table(tableName).logicalPlan) match { case l @ LogicalRelation(i: InsertableRelation) => if (l.schema != createdRelation.schema) { val errorDescription = diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 313e84756b6bb..53ddecf57958b 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import org.apache.spark.sql.sources.SaveMode; +import org.apache.spark.sql.SaveMode; import org.junit.After; import org.junit.Assert; import org.junit.Before; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 89b18c3439cf6..9fcb04ca23590 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -37,7 +37,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { import org.apache.spark.sql.hive.test.TestHive.implicits._ val testData = TestHive.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))) + (1 to 100).map(i => TestData(i, i.toString))).toDF before { // Since every we are doing tests for DDL statements, @@ -56,7 +56,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { // Make sure the table has also been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq.map(Row.fromTuple) + testData.collect().toSeq ) // Add more data. @@ -65,7 +65,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { // Make sure the table has been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.toDataFrame.collect().toSeq ++ testData.toDataFrame.collect().toSeq + testData.toDF.collect().toSeq ++ testData.toDF.collect().toSeq ) // Now overwrite. @@ -74,7 +74,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { // Make sure the registered table has also been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq.map(Row.fromTuple) + testData.collect().toSeq ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala new file mode 100644 index 0000000000000..321b784a3f842 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -0,0 +1,77 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.hive + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.Row + +class ListTablesSuite extends QueryTest with BeforeAndAfterAll { + + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + val df = + sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value") + + override def beforeAll(): Unit = { + // The catalog in HiveContext is a case insensitive one. + catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan) + catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan) + sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") + sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") + sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") + } + + override def afterAll(): Unit = { + catalog.unregisterTable(Seq("ListTablesSuiteTable")) + catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable")) + sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") + sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") + sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") + } + + test("get all tables of current database") { + val allTables = tables() + // We are using default DB. + checkAnswer( + allTables.filter("tableName = 'listtablessuitetable'"), + Row("listtablessuitetable", true)) + assert(allTables.filter("tableName = 'indblisttablessuitetable'").count() === 0) + checkAnswer( + allTables.filter("tableName = 'hivelisttablessuitetable'"), + Row("hivelisttablessuitetable", false)) + assert(allTables.filter("tableName = 'hiveindblisttablessuitetable'").count() === 0) + } + + test("getting all tables with a database name") { + val allTables = tables("ListTablesSuiteDB") + checkAnswer( + allTables.filter("tableName = 'listtablessuitetable'"), + Row("listtablessuitetable", true)) + checkAnswer( + allTables.filter("tableName = 'indblisttablessuitetable'"), + Row("indblisttablessuitetable", true)) + assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0) + checkAnswer( + allTables.filter("tableName = 'hiveindblisttablessuitetable'"), + Row("hiveindblisttablessuitetable", false)) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index f94aabd29ad23..addf887ab9162 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.sql.hive import java.io.File - -import org.apache.spark.sql.sources.SaveMode import org.scalatest.BeforeAndAfterEach import org.apache.commons.io.FileUtils @@ -30,17 +28,14 @@ import org.apache.spark.sql.catalyst.util import org.apache.spark.sql._ import org.apache.spark.util.Utils import org.apache.spark.sql.types._ - -/* Implicits */ import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ /** * Tests for persisting tables created though the data sources API into the metastore. */ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { - import org.apache.spark.sql.hive.test.TestHive.implicits._ - override def afterEach(): Unit = { reset() if (tempPath.exists()) Utils.deleteRecursively(tempPath) @@ -156,7 +151,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { test("check change without refresh") { val tempDir = File.createTempFile("sparksql", "json") tempDir.delete() - sparkContext.parallelize(("a", "b") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + sparkContext.parallelize(("a", "b") :: Nil).toDF + .toJSON.saveAsTextFile(tempDir.getCanonicalPath) sql( s""" @@ -172,7 +168,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { Row("a", "b")) FileUtils.deleteDirectory(tempDir) - sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toDF + .toJSON.saveAsTextFile(tempDir.getCanonicalPath) // Schema is cached so the new column does not show. The updated values in existing columns // will show. @@ -192,7 +189,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { test("drop, change, recreate") { val tempDir = File.createTempFile("sparksql", "json") tempDir.delete() - sparkContext.parallelize(("a", "b") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + sparkContext.parallelize(("a", "b") :: Nil).toDF + .toJSON.saveAsTextFile(tempDir.getCanonicalPath) sql( s""" @@ -208,7 +206,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { Row("a", "b")) FileUtils.deleteDirectory(tempDir) - sparkContext.parallelize(("a", "b", "c") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + sparkContext.parallelize(("a", "b", "c") :: Nil).toDF + .toJSON.saveAsTextFile(tempDir.getCanonicalPath) sql("DROP TABLE jsonTable") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 405b200d05412..d01dbf80ef66d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.plans.logical.Project -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -567,7 +567,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(2, "str2") :: Nil) - testData.registerTempTable("REGisteredTABle") + testData.toDF.registerTempTable("REGisteredTABle") assertResult(Array(Row(2, "str2"))) { sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + @@ -592,7 +592,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} - TestHive.sparkContext.parallelize(fixture).registerTempTable("having_test") + TestHive.sparkContext.parallelize(fixture).toDF.registerTempTable("having_test") val results = sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() @@ -740,7 +740,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(1, "str2") :: Nil) - testData.registerTempTable("test_describe_commands2") + testData.toDF.registerTempTable("test_describe_commands2") assertResult( Array( @@ -900,8 +900,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") { - sparkContext.makeRDD(Seq.empty[LogEntry]).registerTempTable("rawLogs") - sparkContext.makeRDD(Seq.empty[LogFile]).registerTempTable("logFiles") + sparkContext.makeRDD(Seq.empty[LogEntry]).toDF.registerTempTable("rawLogs") + sparkContext.makeRDD(Seq.empty[LogFile]).toDF.registerTempTable("logFiles") sql( """ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 029c36aa89b26..6fc4cc14265ec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -77,7 +77,7 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerTempTable("caseSensitivityTest") + .toDF.registerTempTable("caseSensitivityTest") val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "a", "b", "A", "B"), @@ -88,14 +88,14 @@ class HiveResolutionSuite extends HiveComparisonTest { ignore("case insensitivity with scala reflection joins") { // Test resolution with Scala Reflection sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerTempTable("caseSensitivityTest") + .toDF.registerTempTable("caseSensitivityTest") sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect() } test("nested repeated resolution") { sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerTempTable("nestedRepeatedTest") + .toDF.registerTempTable("nestedRepeatedTest") assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 8fb5e050a237a..ab53c6309e089 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index 1e99003d3e9b5..245161d2ebbca 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -111,7 +111,7 @@ class HiveUdfSuite extends QueryTest { test("UDFIntegerToString") { val testData = TestHive.sparkContext.parallelize( - IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil) + IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF testData.registerTempTable("integerTable") sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'") @@ -127,7 +127,7 @@ class HiveUdfSuite extends QueryTest { val testData = TestHive.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: - ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil) + ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF testData.registerTempTable("listListIntTable") sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") @@ -142,7 +142,7 @@ class HiveUdfSuite extends QueryTest { test("UDFListString") { val testData = TestHive.sparkContext.parallelize( ListStringCaseClass(Seq("a", "b", "c")) :: - ListStringCaseClass(Seq("d", "e")) :: Nil) + ListStringCaseClass(Seq("d", "e")) :: Nil).toDF testData.registerTempTable("listStringTable") sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") @@ -156,7 +156,7 @@ class HiveUdfSuite extends QueryTest { test("UDFStringString") { val testData = TestHive.sparkContext.parallelize( - StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil) + StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF testData.registerTempTable("stringTable") sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") @@ -173,7 +173,7 @@ class HiveUdfSuite extends QueryTest { ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: - Nil) + Nil).toDF testData.registerTempTable("TwoListTable") sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 9a6e8650a0ec4..978825938395f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.hive.HiveShim import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} @@ -34,9 +35,6 @@ case class Nested3(f3: Int) */ class SQLQuerySuite extends QueryTest { - import org.apache.spark.sql.hive.test.TestHive.implicits._ - val sqlCtx = TestHive - test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") { checkAnswer( sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"), @@ -176,7 +174,8 @@ class SQLQuerySuite extends QueryTest { } test("double nested data") { - sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested") + sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil) + .toDF().registerTempTable("nested") checkAnswer( sql("SELECT f1.f2.f3 FROM nested"), Row(1)) @@ -199,7 +198,7 @@ class SQLQuerySuite extends QueryTest { } test("SPARK-4825 save join to table") { - val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() sql("CREATE TABLE test1 (key INT, value STRING)") testData.insertInto("test1") sql("CREATE TABLE test2 (key INT, value STRING)") @@ -279,7 +278,7 @@ class SQLQuerySuite extends QueryTest { val rowRdd = sparkContext.parallelize(row :: Nil) - sqlCtx.createDataFrame(rowRdd, schema).registerTempTable("testTable") + TestHive.createDataFrame(rowRdd, schema).registerTempTable("testTable") sql( """CREATE TABLE nullValuesInInnerComplexTypes diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala index a7479a5b95864..2acf1a7767c19 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ + // The data where the partitioning key exists only in the directory structure. case class ParquetData(intField: Int, stringField: String) @@ -38,7 +40,7 @@ case class ParquetDataWithKey(p: Int, intField: Int, stringField: String) * A suite to test the automatic conversion of metastore tables with parquet data to use the * built in parquet support. */ -class ParquetMetastoreSuite extends ParquetPartitioningTest { +class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { override def beforeAll(): Unit = { super.beforeAll() @@ -95,6 +97,9 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } override def afterAll(): Unit = { + sql("DROP TABLE partitioned_parquet") + sql("DROP TABLE partitioned_parquet_with_key") + sql("DROP TABLE normal_parquet") setConf("spark.sql.hive.convertMetastoreParquet", "false") } @@ -111,10 +116,38 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } } +class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { + val originalConf = conf.parquetUseDataSourceApi + + override def beforeAll(): Unit = { + super.beforeAll() + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override def afterAll(): Unit = { + super.afterAll() + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + +class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { + val originalConf = conf.parquetUseDataSourceApi + + override def beforeAll(): Unit = { + super.beforeAll() + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + } + + override def afterAll(): Unit = { + super.afterAll() + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + /** * A suite of tests for the Parquet support through the data sources API. */ -class ParquetSourceSuite extends ParquetPartitioningTest { +class ParquetSourceSuiteBase extends ParquetPartitioningTest { override def beforeAll(): Unit = { super.beforeAll() @@ -144,6 +177,34 @@ class ParquetSourceSuite extends ParquetPartitioningTest { } } +class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { + val originalConf = conf.parquetUseDataSourceApi + + override def beforeAll(): Unit = { + super.beforeAll() + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override def afterAll(): Unit = { + super.afterAll() + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + +class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase { + val originalConf = conf.parquetUseDataSourceApi + + override def beforeAll(): Unit = { + super.beforeAll() + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + } + + override def afterAll(): Unit = { + super.afterAll() + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + /** * A collection of tests for parquet data with various forms of partitioning. */ @@ -152,7 +213,6 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll var normalTableDir: File = null var partitionedTableDirWithKey: File = null - import org.apache.spark.sql.hive.test.TestHive.implicits._ override def beforeAll(): Unit = { partitionedTableDir = File.createTempFile("parquettests", "sparksql") @@ -167,12 +227,14 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll val partDir = new File(partitionedTableDir, s"p=$p") sparkContext.makeRDD(1 to 10) .map(i => ParquetData(i, s"part-$p")) + .toDF() .saveAsParquetFile(partDir.getCanonicalPath) } sparkContext .makeRDD(1 to 10) .map(i => ParquetData(i, s"part-1")) + .toDF() .saveAsParquetFile(new File(normalTableDir, "normal").getCanonicalPath) partitionedTableDirWithKey = File.createTempFile("parquettests", "sparksql") @@ -183,111 +245,104 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll val partDir = new File(partitionedTableDirWithKey, s"p=$p") sparkContext.makeRDD(1 to 10) .map(i => ParquetDataWithKey(p, i, s"part-$p")) + .toDF() .saveAsParquetFile(partDir.getCanonicalPath) } } - def run(prefix: String): Unit = { - Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table => - test(s"$prefix: ordering of the partitioning columns $table") { - checkAnswer( - sql(s"SELECT p, stringField FROM $table WHERE p = 1"), - Seq.fill(10)(Row(1, "part-1")) - ) - - checkAnswer( - sql(s"SELECT stringField, p FROM $table WHERE p = 1"), - Seq.fill(10)(Row("part-1", 1)) - ) - } - - test(s"$prefix: project the partitioning column $table") { - checkAnswer( - sql(s"SELECT p, count(*) FROM $table group by p"), - Row(1, 10) :: - Row(2, 10) :: - Row(3, 10) :: - Row(4, 10) :: - Row(5, 10) :: - Row(6, 10) :: - Row(7, 10) :: - Row(8, 10) :: - Row(9, 10) :: - Row(10, 10) :: Nil - ) - } - - test(s"$prefix: project partitioning and non-partitioning columns $table") { - checkAnswer( - sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), - Row("part-1", 1, 10) :: - Row("part-2", 2, 10) :: - Row("part-3", 3, 10) :: - Row("part-4", 4, 10) :: - Row("part-5", 5, 10) :: - Row("part-6", 6, 10) :: - Row("part-7", 7, 10) :: - Row("part-8", 8, 10) :: - Row("part-9", 9, 10) :: - Row("part-10", 10, 10) :: Nil - ) - } - - test(s"$prefix: simple count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table"), - Row(100)) - } - - test(s"$prefix: pruned count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), - Row(10)) - } - - test(s"$prefix: non-existent partition $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), - Row(0)) - } - - test(s"$prefix: multi-partition pruned count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), - Row(30)) - } - - test(s"$prefix: non-partition predicates $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), - Row(30)) - } - - test(s"$prefix: sum $table") { - checkAnswer( - sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), - Row(1 + 2 + 3)) - } - - test(s"$prefix: hive udfs $table") { - checkAnswer( - sql(s"SELECT concat(stringField, stringField) FROM $table"), - sql(s"SELECT stringField FROM $table").map { - case Row(s: String) => Row(s + s) - }.collect().toSeq) - } + Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table => + test(s"ordering of the partitioning columns $table") { + checkAnswer( + sql(s"SELECT p, stringField FROM $table WHERE p = 1"), + Seq.fill(10)(Row(1, "part-1")) + ) + + checkAnswer( + sql(s"SELECT stringField, p FROM $table WHERE p = 1"), + Seq.fill(10)(Row("part-1", 1)) + ) + } + + test(s"project the partitioning column $table") { + checkAnswer( + sql(s"SELECT p, count(*) FROM $table group by p"), + Row(1, 10) :: + Row(2, 10) :: + Row(3, 10) :: + Row(4, 10) :: + Row(5, 10) :: + Row(6, 10) :: + Row(7, 10) :: + Row(8, 10) :: + Row(9, 10) :: + Row(10, 10) :: Nil + ) + } + + test(s"project partitioning and non-partitioning columns $table") { + checkAnswer( + sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), + Row("part-1", 1, 10) :: + Row("part-2", 2, 10) :: + Row("part-3", 3, 10) :: + Row("part-4", 4, 10) :: + Row("part-5", 5, 10) :: + Row("part-6", 6, 10) :: + Row("part-7", 7, 10) :: + Row("part-8", 8, 10) :: + Row("part-9", 9, 10) :: + Row("part-10", 10, 10) :: Nil + ) + } + + test(s"simple count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table"), + Row(100)) } - test(s"$prefix: $prefix: non-part select(*)") { + test(s"pruned count $table") { checkAnswer( - sql("SELECT COUNT(*) FROM normal_parquet"), + sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), Row(10)) } - } - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") - run("Parquet data source enabled") + test(s"non-existent partition $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), + Row(0)) + } + + test(s"multi-partition pruned count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), + Row(30)) + } + + test(s"non-partition predicates $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), + Row(30)) + } - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") - run("Parquet data source disabled") + test(s"sum $table") { + checkAnswer( + sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), + Row(1 + 2 + 3)) + } + + test(s"hive udfs $table") { + checkAnswer( + sql(s"SELECT concat(stringField, stringField) FROM $table"), + sql(s"SELECT stringField FROM $table").map { + case Row(s: String) => Row(s + s) + }.collect().toSeq) + } + } + + test("non-part select(*)") { + checkAnswer( + sql("SELECT COUNT(*) FROM normal_parquet"), + Row(10)) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index b780282bdac37..f88a8a0151550 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -152,7 +152,7 @@ class CheckpointWriter( // Delete old checkpoint files val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs) - if (allCheckpointFiles.size > 4) { + if (allCheckpointFiles.size > 10) { allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => { logInfo("Deleting " + file) fs.delete(file, true) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 6d5b8fda76ab8..c1d3f7320f53c 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -38,7 +38,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} -import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.network.util.JavaUtils class ExecutorRunnable( @@ -109,7 +109,13 @@ class ExecutorRunnable( } // Send the start request to the ContainerManager - nmClient.startContainer(container, ctx) + try { + nmClient.startContainer(container, ctx) + } catch { + case ex: Exception => + throw new SparkException(s"Exception while starting container ${container.getId}" + + s" on host $hostname", ex) + } } private def prepareCommand(